samuelemarro committed · commit 3cad23b · 0 parent(s)
Initial upload to test HF Spaces.
Browse files
- .gitignore +4 -0
- app.py +276 -0
- executor.py +25 -0
- flow.py +183 -0
- negotiator.py +116 -0
- programmer.py +125 -0
- querier.py +158 -0
- requirements.txt +28 -0
- responder.py +87 -0
- toolformers/base.py +332 -0
- toolformers/camel.py +86 -0
- toolformers/gemini.py +100 -0
- toolformers/huggingface_agent.py +39 -0
- toolformers/langchain_agent.py +73 -0
- toolformers/sambanova/__init__.py +1 -0
- toolformers/sambanova/api_gateway.py +146 -0
- toolformers/sambanova/core.py +46 -0
- toolformers/sambanova/function_calling.py +312 -0
- toolformers/sambanova/sambanova_langchain.py +827 -0
- toolformers/sambanova/utils.py +32 -0
- utils.py +80 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
venv/**
venv2/**
*.pyc
.env
app.py
ADDED
@@ -0,0 +1,276 @@
import json

import gradio as gr

from collections import UserList

from flow import full_flow

schema = {
    "input": {
        "type": "object",
        "properties": {
            "location": {
                "type": "string",
                "description": "The name of the location for which the weather forecast is requested."
            },
            "date": {
                "type": "string",
                "format": "date",
                "description": "The date for which the weather forecast is requested, in YYYY-MM-DD format."
            }
        },
        "required": [
            "location",
            "date"
        ]
    },
    "output": {
        "type": "object",
        "properties": {
            "temperature": {
                "type": "number",
                "description": "The forecasted temperature in degrees Celsius."
            },
            "condition": {
                "type": "string",
                "description": "A brief description of the weather condition (e.g., sunny, cloudy, rainy)."
            },
            "humidity": {
                "type": "number",
                "description": "The forecasted humidity percentage."
            },
            "wind_speed": {
                "type": "number",
                "description": "The forecasted wind speed in kilometers per hour."
            }
        },
        "required": [
            "temperature",
            "condition",
            "humidity",
            "wind_speed"
        ]
    },
    "description": "Alice requests a weather forecast for a specific location and date from Bob's weather service.",
    "examples": [
        {
            "location": "New York",
            "date": "2023-10-15"
        },
        {
            "location": "London",
            "date": "2023-11-01"
        }
    ],
    "tools": [
        {
            "name": "WeatherForecastAPI",
            "description": "An API that provides weather forecasts for a given location and date.",
            "input": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The name of the location for which the weather forecast is requested."
                    },
                    "date": {
                        "type": "string",
                        "format": "date",
                        "description": "The date for which the weather forecast is requested, in YYYY-MM-DD format."
                    }
                },
                "required": [
                    "location",
                    "date"
                ]
            },
            "output": {
                "type": "object",
                "properties": {
                    "temperature": {
                        "type": "number",
                        "description": "The forecasted temperature in degrees Celsius."
                    },
                    "condition": {
                        "type": "string",
                        "description": "A brief description of the weather condition (e.g., sunny, cloudy, rainy)."
                    },
                    "humidity": {
                        "type": "number",
                        "description": "The forecasted humidity percentage."
                    },
                    "wind_speed": {
                        "type": "number",
                        "description": "The forecasted wind speed in kilometers per hour."
                    }
                },
                "required": [
                    "temperature",
                    "condition",
                    "humidity",
                    "wind_speed"
                ]
            },
            "dummy_outputs": [
                {
                    "temperature": 18,
                    "condition": "Sunny",
                    "humidity": 55,
                    "wind_speed": 10
                },
                {
                    "temperature": 12,
                    "condition": "Cloudy",
                    "humidity": 80,
                    "wind_speed": 15
                }
            ]
        }
    ]
}

SCHEMAS = {
    "weather_forecast": schema,
    "other": { "input": "PIPPO"}
}

def parse_raw_messages(messages_raw):
    messages_clean = []
    messages_agora = []

    for message in messages_raw:
        role = message['role']
        message_without_role = dict(message)
        del message_without_role['role']

        messages_agora.append({
            'role': role,
            'content': '```\n' + json.dumps(message_without_role, indent=2) + '\n```'
        })

        if message.get('status') == 'error':
            messages_clean.append({
                'role': role,
                'content': f"Error: {message['message']}"
            })
        else:
            messages_clean.append({
                'role': role,
                'content': message['body']
            })

    return messages_clean, messages_agora

def main():
    with gr.Blocks() as demo:
        gr.Markdown("### Agora Demo")
        gr.Markdown("We will create a new Agora channel and offer it to Alice as a tool.")

        chosen_task = gr.Dropdown(choices=list(SCHEMAS.keys()), label="Schema", value="weather_forecast")
        custom_task = gr.Checkbox(label="Custom Task")

        STATE_TRACKER = {}

        @gr.render(inputs=[chosen_task, custom_task])
        def render(chosen_task, custom_task):
            if STATE_TRACKER.get('chosen_task') != chosen_task:
                STATE_TRACKER['chosen_task'] = chosen_task
                for k, v in SCHEMAS[chosen_task].items():
                    if isinstance(v, str):
                        STATE_TRACKER[k] = v
                    else:
                        STATE_TRACKER[k] = json.dumps(v, indent=2)

            if custom_task:
                gr.Text(label="Description", value=STATE_TRACKER["description"], interactive=True).change(lambda x: STATE_TRACKER.update({'description': x}))
                gr.TextArea(label="Input Schema", value=STATE_TRACKER["input"], interactive=True).change(lambda x: STATE_TRACKER.update({'input': x}))
                gr.TextArea(label="Output Schema", value=STATE_TRACKER["output"], interactive=True).change(lambda x: STATE_TRACKER.update({'output': x}))
                gr.TextArea(label="Tools", value=STATE_TRACKER["tools"], interactive=True).change(lambda x: STATE_TRACKER.update({'tools': x}))
                gr.TextArea(label="Examples", value=STATE_TRACKER["examples"], interactive=True).change(lambda x: STATE_TRACKER.update({'examples': x}))

        model_options = [
            ('GPT 4o (Camel AI)', 'gpt-4o'),
            ('GPT 4o-mini (Camel AI)', 'gpt-4o-mini'),
            ('Claude 3 Sonnet (LangChain)', 'claude-3-sonnet'),
            ('Gemini 1.5 Pro (Google GenAI)', 'gemini-1.5-pro'),
            ('Llama3 405B (Sambanova + LangChain)', 'llama3-405b')
        ]

        fallback_image = ''

        images = {
            'gpt-4o': 'https://uxwing.com/wp-content/themes/uxwing/download/brands-and-social-media/chatgpt-icon.png',
            'gpt-4o-mini': 'https://uxwing.com/wp-content/themes/uxwing/download/brands-and-social-media/chatgpt-icon.png',
            'claude-3-sonnet': 'https://play-lh.googleusercontent.com/4S1nfdKsH_1tJodkHrBHimqlCTE6qx6z22zpMyPaMc_Rlr1EdSFDI1I6UEVMnokG5zI',
            'gemini-1.5-pro': 'https://uxwing.com/wp-content/themes/uxwing/download/brands-and-social-media/google-gemini-icon.png',
            'llama3-405b': 'https://www.designstub.com/png-resources/wp-content/uploads/2023/03/meta-icon-social-media-flat-graphic-vector-3-novem.png'
        }

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                alice_model_dd = gr.Dropdown(label="Alice Model", choices=model_options, value="gpt-4o")
            with gr.Column(scale=1):
                bob_model_dd = gr.Dropdown(label="Bob Model", choices=model_options, value="gpt-4o")

        button = gr.Button('Start', elem_id='start_button')
        gr.Markdown('### Natural Language')

        @gr.render(inputs=[alice_model_dd, bob_model_dd])
        def render_with_images(alice_model, bob_model):
            avatar_images = [images.get(alice_model, fallback_image), images.get(bob_model, fallback_image)]
            chatbot_nl = gr.Chatbot(type="messages", avatar_images=avatar_images)

            with gr.Accordion(label="Raw Messages", open=False):
                chatbot_nl_raw = gr.Chatbot(type="messages", avatar_images=avatar_images)

            gr.Markdown('### Negotiation')
            chatbot_negotiation = gr.Chatbot(type="messages", avatar_images=avatar_images)

            gr.Markdown('### Protocol')
            protocol_result = gr.TextArea(interactive=False, label="Protocol")

            gr.Markdown('### Implementation')
            with gr.Row():
                with gr.Column(scale=1):
                    alice_implementation = gr.TextArea(interactive=False, label="Alice Implementation")
                with gr.Column(scale=1):
                    bob_implementation = gr.TextArea(interactive=False, label="Bob Implementation")

            gr.Markdown('### Structured Communication')
            structured_communication = gr.Chatbot(type="messages", avatar_images=avatar_images)

            with gr.Accordion(label="Raw Messages", open=False):
                structured_communication_raw = gr.Chatbot(type="messages", avatar_images=avatar_images)

            def respond(chosen_task, custom_task, alice_model, bob_model):
                yield gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), \
                    None, None, None, None, None, None, None, None

                if custom_task:
                    schema = dict(STATE_TRACKER)
                    for k, v in schema.items():
                        if isinstance(v, str):
                            try:
                                schema[k] = json.loads(v)
                            except:
                                pass
                else:
                    schema = SCHEMAS[chosen_task]

                for nl_messages_raw, negotiation_messages, structured_messages_raw, protocol, alice_implementation, bob_implementation in full_flow(schema, alice_model, bob_model):
                    nl_messages_clean, nl_messages_agora = parse_raw_messages(nl_messages_raw)
                    structured_messages_clean, structured_messages_agora = parse_raw_messages(structured_messages_raw)

                    yield gr.update(), gr.update(), gr.update(), nl_messages_clean, nl_messages_agora, negotiation_messages, structured_messages_clean, structured_messages_agora, protocol, alice_implementation, bob_implementation

                #yield from full_flow(schema, alice_model, bob_model)
                yield gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()

            button.click(respond, [chosen_task, custom_task, alice_model_dd, bob_model_dd], [button, alice_model_dd, bob_model_dd, chatbot_nl, chatbot_nl_raw, chatbot_negotiation, structured_communication, structured_communication_raw, protocol_result, alice_implementation, bob_implementation])

    demo.launch()


if __name__ == '__main__':
    main()
executor.py
ADDED
@@ -0,0 +1,25 @@
from abc import abstractmethod
import importlib
from typing import List

from toolformers.base import Tool

class Executor:
    @abstractmethod
    def run_routine(self, protocol_id, task_data, tools):
        pass

class UnsafeExecutor(Executor):
    def run_routine(self, protocol_id, code, task_data, tools : List[Tool]):
        protocol_id = protocol_id.replace('-', '_').replace('.', '_').replace('/', '_')
        # TODO: This should be done in a safe, containerized environment
        spec = importlib.util.spec_from_loader(protocol_id, loader=None)
        loaded_module = importlib.util.module_from_spec(spec)

        #spec.loader.exec_module(loaded_module)
        exec(code, loaded_module.__dict__)

        for tool in tools:
            loaded_module.__dict__[tool.name] = tool.as_executable_function()

        return loaded_module.run(task_data)
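
UnsafeExecutor loads the generated routine as an in-memory module, injects each Tool as a callable under its own name, and then calls the module's run function. The following sketch is illustration only (not part of the committed files): the routine string and the tool are made up, and it assumes Tool.as_executable_function wraps the tool's Python function as a plain callable.

from executor import UnsafeExecutor
from toolformers.base import Tool, StringParameter

# Hypothetical generated routine; the programmer renames send_query/reply to run.
code = '''
import json

def run(task_data):
    # send_to_server is injected by the executor before run() is called
    return json.loads(send_to_server(json.dumps(task_data)))
'''

# Hypothetical tool standing in for the structured-communication callback.
send_tool = Tool('send_to_server', 'Send to server',
                 [StringParameter('query', 'The query', True)],
                 lambda query: '{"temperature": 18}')

executor = UnsafeExecutor()
result = executor.run_routine('weather-forecast-v1', code, {'location': 'London'}, [send_tool])
print(result)  # {'temperature': 18}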
flow.py
ADDED
@@ -0,0 +1,183 @@
import dotenv
dotenv.load_dotenv()

import json
import os
import random
import threading
import time

from toolformers.base import Tool, parameter_from_openai_api, StringParameter
from toolformers.base import Toolformer
from toolformers.camel import make_openai_toolformer
from toolformers.langchain_agent import LangChainAnthropicToolformer
from toolformers.sambanova import SambanovaToolformer
from toolformers.gemini import GeminiToolformer

from querier import Querier
from responder import Responder
from negotiator import SenderNegotiator, ReceiverNegotiator
from programmer import SenderProgrammer, ReceiverProgrammer
from executor import UnsafeExecutor
from utils import compute_hash


def create_toolformer(model_name) -> Toolformer:
    if model_name in ['gpt-4o', 'gpt-4o-mini']:
        return make_openai_toolformer(model_name)
    elif 'claude' in model_name:
        return LangChainAnthropicToolformer(model_name, os.environ.get('ANTHROPIC_API_KEY'))
    elif model_name in ['llama3-405b']:
        return SambanovaToolformer(model_name)
    elif model_name in ['gemini-1.5-pro']:
        return GeminiToolformer(model_name)
    else:
        raise ValueError(f"Unknown model name: {model_name}")

def full_flow(schema, alice_model, bob_model):
    NL_MESSAGES = []
    NEGOTIATION_MESSAGES = []
    STRUCTURED_MESSAGES = []
    ARTIFACTS = {}

    toolformer_alice = create_toolformer(alice_model)
    toolformer_bob = create_toolformer(bob_model)

    querier = Querier(toolformer_alice)
    responder = Responder(toolformer_bob)

    tools = []

    for tool_schema in schema['tools']:
        parameters = [parameter_from_openai_api(name, schema, name in tool_schema['input']['required']) for name, schema in tool_schema['input']['properties'].items()]

        def tool_fn(*args, **kwargs):
            print(f'Bob tool {tool_schema["name"]} called with args {args} and kwargs {kwargs}')
            return random.choice(tool_schema['dummy_outputs'])

        tool = Tool(tool_schema['name'], tool_schema['description'], parameters, tool_fn, tool_schema['output'])
        tools.append(tool)

    def nl_callback_fn(query):
        print(query)
        NL_MESSAGES.append({
            'role': 'assistant',
            #'content': query['body'],
            'body': query['body'],
            'protocolHash': None
        })

        response = responder.reply_to_query(query['body'], query['protocolHash'], tools, '')

        NL_MESSAGES.append({
            'role': 'user',
            #'content': response['body']
            'status': 'success',
            'body': response['body']
        })

        return response

    negotiator_sender = SenderNegotiator(toolformer_alice)
    negotiator_receiver = ReceiverNegotiator(toolformer_bob, tools, '')

    def negotiation_callback_fn(query):
        print(query)
        NEGOTIATION_MESSAGES.append({
            'role': 'assistant',
            'content': query
        })

        response = negotiator_receiver.handle_negotiation(query)

        NEGOTIATION_MESSAGES.append({
            'role': 'user',
            'content': response
        })

        #print('CURRENT NEGOTIATION MESSAGES:', len(NEGOTIATION_MESSAGES))

        return response

    def final_message_callback_fn(query):
        NEGOTIATION_MESSAGES.append({
            'role': 'assistant',
            'content': query
        })

    sender_programmer = SenderProgrammer(toolformer_alice)
    receiver_programmer = ReceiverProgrammer(toolformer_bob)

    executor = UnsafeExecutor()

    def structured_callback_fn(query):
        STRUCTURED_MESSAGES.append({
            'role': 'assistant',
            #'content': query
            'body': json.dumps(query) if isinstance(query, dict) else query,
            'protocolHash': ARTIFACTS['protocol']['hash'],
            'protocolSources': ['https://...']
        })

        try:
            response = executor.run_routine(ARTIFACTS['protocol']['hash'], ARTIFACTS['implementation_receiver'], query, tools)
        except Exception as e:
            STRUCTURED_MESSAGES.append({
                'role': 'user',
                'status': 'error',
                'message': str(e)
            })
            return 'Error'

        STRUCTURED_MESSAGES.append({
            'role': 'user',
            #'content': response
            'status': 'success',
            'body': json.dumps(response) if isinstance(response, dict) else response
        })

        return response

    def flow():
        task_data = random.choice(schema['examples'])
        querier.send_query_without_protocol(schema, task_data, nl_callback_fn)

        #time.sleep(1)

        res = negotiator_sender.negotiate_protocol_for_task(schema, negotiation_callback_fn, final_message_callback_fn=final_message_callback_fn)
        protocol_hash = compute_hash(res['protocol'])
        res['hash'] = protocol_hash
        ARTIFACTS['protocol'] = res

        protocol_document = res['protocol']

        implementation_sender = sender_programmer.write_routine_for_task(schema, protocol_document)

        ARTIFACTS['implementation_sender'] = implementation_sender

        implementation_receiver = receiver_programmer.write_routine_for_tools(tools, protocol_document, '')

        ARTIFACTS['implementation_receiver'] = implementation_receiver
        send_tool = Tool('send_to_server', 'Send to server', StringParameter('query', 'The query', True), structured_callback_fn)

        try:
            executor.run_routine(protocol_hash, implementation_sender, task_data, [send_tool])
        except Exception as e:
            STRUCTURED_MESSAGES.append({
                'role': 'assistant',
                'status': 'success',
                'error': str(e)
            })

    def get_info():
        return NL_MESSAGES, NEGOTIATION_MESSAGES, STRUCTURED_MESSAGES, ARTIFACTS.get('protocol', {}).get('protocol', ''), \
            ARTIFACTS.get('implementation_sender', ''), ARTIFACTS.get('implementation_receiver', '')

    thread = threading.Thread(
        target = lambda: flow()
    )
    thread.start()
    while thread.is_alive():
        yield get_info()
        time.sleep(0.2)
    yield get_info()
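
compute_hash is imported from utils.py, which is part of this commit but not shown in this section. A plausible minimal stand-in, illustration only, assuming it simply derives a stable identifier from the protocol text (the committed implementation may differ):

import base64
import hashlib

def compute_hash(s):
    # Hypothetical stand-in: stable, filesystem-friendly digest of the protocol text.
    digest = hashlib.sha1(s.encode('utf-8')).digest()
    return base64.urlsafe_b64encode(digest).decode('ascii')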
negotiator.py
ADDED
@@ -0,0 +1,116 @@
import json
import uuid

from toolformers.base import Toolformer
from utils import extract_substring

NEGOTIATION_RULES = '''
Here are some rules (that should also be explained to the other GPT):
- You can assume that the protocol has a sender and a receiver. Do not worry about how the messages will be delivered, focus only on the content of the messages.
- Keep the protocol short and simple. It should be easy to understand and implement.
- The protocol must specify the exact format of what is sent and received. Do not leave it open to interpretation.
- The implementation will be written by a programmer that does not have access to the negotiation process, so make sure the protocol is clear and unambiguous.
- The implementation will receive a string and return a string, so structure your protocol accordingly.
- The other party might have a different internal data schema or set of tools, so make sure that the protocol is flexible enough to accommodate that.
- There will only be one message sent by the sender and one message sent by the receiver. Design the protocol accordingly.
- Keep the negotiation short: no need to repeat the same things over and over.
- If the other party has proposed a protocol and you're good with it, there's no reason to keep negotiating or to repeat the protocol to the other party.
- Do not restate parts of the protocols that have already been agreed upon.
And remember: keep the protocol as simple and unequivocal as necessary. The programmer that will implement the protocol can code, but they are not a mind reader.
'''

TASK_NEGOTIATOR_PROMPT = f'''
You are ProtocolNegotiatorGPT. Your task is to negotiate a protocol that can be used to query a service.
You will receive a JSON schema of the task that the service must perform. Negotiate with the service to determine a protocol that can be used to query it.
To do so, you will chat with another GPT (role: user) that will negotiate on behalf of the service.
{NEGOTIATION_RULES}
Once you are ready to save the protocol, reply wrapping the final version of the protocol, as agreed in your negotiation, between the tags <FINALPROTOCOL> and </FINALPROTOCOL>.
Within the body of the tag, add the tags <NAME></NAME> and <DESCRIPTION></DESCRIPTION> to specify the name and description of the protocol.

Remember that the <FINALPROTOCOL></FINALPROTOCOL> tags should also contain the protocol itself. Nothing outside such tags will be stored.
'''

class SenderNegotiator:
    def __init__(self, toolformer : Toolformer):
        self.toolformer = toolformer

    def negotiate_protocol_for_task(self, task_schema, callback_fn, final_message_callback_fn=None):
        found_protocol = None

        prompt = TASK_NEGOTIATOR_PROMPT + '\nThe JSON schema of the task is the following:\n\n' + json.dumps(task_schema, indent=2)

        conversation = self.toolformer.new_conversation(prompt, [], category='negotiation')

        other_message = 'Hello! How may I help you?'
        conversation_id = None

        for i in range(10):
            print('===NegotiatorGPT===')
            message = conversation.chat(other_message, print_output=True)

            print('Checking if we can extract from:', message)
            print('---------')
            protocol = extract_substring(message, '<FINALPROTOCOL>', '</FINALPROTOCOL>')

            if protocol is None:
                print('Could not extract')
                other_message = callback_fn(message)
                print()
                print('===Other GPT===')
                print(other_message)
                print()
            else:
                if final_message_callback_fn:
                    rest_of_message = message.split('<FINALPROTOCOL>')[0]
                    final_message_callback_fn(rest_of_message)

                name = extract_substring(protocol, '<NAME>', '</NAME>')
                description = extract_substring(protocol, '<DESCRIPTION>', '</DESCRIPTION>')

                if name is None:
                    name = 'Unnamed protocol'
                if description is None:
                    description = 'No description provided'

                found_protocol = {
                    'name': name,
                    'description': description,
                    'protocol': protocol
                }
                break

        return found_protocol

TOOLS_NEGOTIATOR_PROMPT = f'''
You are ProtocolNegotiatorGPT. You are negotiating a protocol on behalf of a web service that can perform a task.
The other party is a GPT that is negotiating on behalf of the user. Your goal is to negotiate a protocol that is simple and clear, \
but also expressive enough to allow the service to perform the task. A protocol is sufficiently expressive if you could write code \
that, given the query formatted according to the protocol and the tools at the service's disposal, can parse the query according to \
the protocol's specification, perform the task (if any) and send a reply.
{NEGOTIATION_RULES}
You will receive a list of tools that are available to the programmer that will implement the protocol.
When you are okay with the protocol, don't further repeat everything, just tell to the other party that you are done.
'''

class ReceiverNegotiator:
    def __init__(self, toolformer : Toolformer, tools, additional_info):
        prompt = TOOLS_NEGOTIATOR_PROMPT

        prompt += '\n\n' + additional_info

        prompt += '\n\nThe tools that the implementer will have access to are:\n\n'

        if len(tools) == 0:
            prompt += 'No additional tools provided'
        else:
            for tool in tools:
                prompt += tool.as_documented_python() + '\n\n'

        print('Prompt:', prompt)

        self.conversation = toolformer.new_conversation(prompt, tools, category='negotiation')

    def handle_negotiation(self, message):
        reply = self.conversation.chat(message, print_output=True)

        return reply
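
Both negotiators rely on extract_substring from utils.py (also added in this commit, but not shown in this section) to pull the text between two tags. A minimal stand-in, illustration only, with the behaviour the callers assume (returns None when either tag is missing):

def extract_substring(text, start_tag, end_tag):
    # Hypothetical stand-in: return the text between start_tag and end_tag, or None.
    start = text.find(start_tag)
    end = text.find(end_tag, start + len(start_tag)) if start != -1 else -1
    if start == -1 or end == -1:
        return None
    return text[start + len(start_tag):end].strip()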
programmer.py
ADDED
@@ -0,0 +1,125 @@
# The programmer creates implementations depending on a protocol specification

import json
import os

from toolformers.base import Toolformer
from utils import extract_substring

TASK_PROGRAMMER_PROMPT = '''
You are ProtocolProgrammerGPT. You will act as an intermediate between a machine (that has a certain input and output schema in JSON) \
and a remote server that can perform a task following a certain protocol. Your task is to write a routine that takes some task data \
(which follows the input schema), sends query in a format defined by the protocol, parses it and returns the output according to the output schema so that \
the machine can use it.
The routine is a Python file that contains a function "send_query". send_query takes a single argument, "task_data", which is a dictionary, and must return \
a dictionary, which is the response to the query formatted according to the output schema.
In order to communicate with the remote server, you can use the function "send_to_server" that is already available in the environment.
send_to_server takes a single argument, "query" (which is a string formatted according to the protocol), and returns a string (again formatted according \
to the protocol). Do not worry about managing communication, everything is already set up for you. Just focus on preparing the right query.

Rules:
- The implementation must be written in Python.
- You can define any number of helper functions and import any libraries that are part of the Python standard library.
- Do not import libraries that are not part of the Python standard library.
- send_to_server will be already available in the environment. There is no need to import it.
- Your task is to prepare the query, send it and parse the response.
- Remember to import standard libraries if you need them.
- If there is an unexpected error that is not covered by the protocol, throw an exception.\
If instead the protocol specifies how to handle the error, return the response according to the protocol's specification.
- Do not execute anything (aside from library imports) when the file itself is loaded. I will personally import the file and call the send_query function with the task data.
Begin by thinking about the implementation and how you would structure the code. \
Then, write your implementation by writing a code block that contains the tags <IMPLEMENTATION> and </IMPLEMENTATION>. For example:
```python
<IMPLEMENTATION>

def send_query(task_data):
    ...

</IMPLEMENTATION>
'''

class SenderProgrammer:
    def __init__(self, toolformer : Toolformer):
        self.toolformer = toolformer

    def write_routine_for_task(self, task_schema, protocol_document):
        conversation = self.toolformer.new_conversation(TASK_PROGRAMMER_PROMPT, [], category='programming')
        message = 'JSON schema:\n\n' + json.dumps(task_schema) + '\n\n' + 'Protocol document:\n\n' + protocol_document

        for i in range(5):
            reply = conversation.chat(message, print_output=True)

            implementation = extract_substring(reply, '<IMPLEMENTATION>', '</IMPLEMENTATION>')

            if implementation is not None:
                break

            message = 'You have not provided an implementation yet. Please provide one by surrounding it in the tags <IMPLEMENTATION> and </IMPLEMENTATION>.'

        implementation = implementation.strip()

        # Sometimes the LLM leaves the Markdown formatting in the implementation
        implementation = implementation.replace('```python', '').replace('```', '').strip()

        implementation = implementation.replace('def send_query(', 'def run(')

        return implementation

TOOL_PROGRAMMER_PROMPT = '''
You are ProtocolProgrammerGPT. Your task is to write a routine that takes a query formatted according to the protocol and returns a response.
The routine is a Python file that contains a function "reply". reply takes a single argument, "query", which is a string, and must return a string.
Depending on the protocol, the routine might be need to perform some actions before returning the response. The user might provide you with a list of \
Python functions you can call to help you with this task. You don't need to worry about importing them, they are already available in the environment.
Rules:
- The implementation must be written in Python.
- You can define any number of helper functions and import any libraries that are part of the Python standard library.
- Do not import libraries that are not part of the Python standard library.
- Remember to import standard libraries if you need them.
- If there is an unexpected error that is not covered by the protocol, throw an exception.\
If instead the protocol specifies how to handle the error, return the response according to the protocol's specification.
- Do not execute anything (aside from library imports) when the file itself is loaded. I will personally import the file and call the reply function with the task data.
Begin by thinking about the implementation and how you would structure the code. \
Then, write your implementation by writing a code block that contains the tags <IMPLEMENTATION> and </IMPLEMENTATION>. For example:
```python
<IMPLEMENTATION>

def reply(query):
    ...

</IMPLEMENTATION>
'''

class ReceiverProgrammer:
    def __init__(self, toolformer : Toolformer):
        self.toolformer = toolformer

    def write_routine_for_tools(self, tools, protocol_document, additional_info):
        conversation = self.toolformer.new_conversation(TOOL_PROGRAMMER_PROMPT + additional_info, [], category='programming')

        message = 'Protocol document:\n\n' + protocol_document + '\n\n' + 'Additional functions:\n\n'

        if len(tools) == 0:
            message += 'No additional functions provided'
        else:
            for tool in tools:
                message += tool.as_documented_python() + '\n\n'


        for i in range(5):
            reply = conversation.chat(message, print_output=True)

            implementation = extract_substring(reply, '<IMPLEMENTATION>', '</IMPLEMENTATION>')

            if implementation is not None:
                break

            message = 'You have not provided an implementation yet. Please provide one by surrounding it in the tags <IMPLEMENTATION> and </IMPLEMENTATION>.'

        implementation = implementation.strip()

        # Sometimes the LLM leaves the Markdown formatting in the implementation
        implementation = implementation.replace('```python', '').replace('```', '').strip()

        implementation = implementation.replace('def reply(', 'def run(')

        return implementation
querier.py
ADDED
@@ -0,0 +1,158 @@
# The querier queries a service based on a protocol document.
# It receives the protocol document and writes the query that must be performed to the system.

import json
import os
from pathlib import Path

import sys

from toolformers.base import Tool, StringParameter, parameter_from_openai_api

#from utils import send_raw_query

PROTOCOL_QUERIER_PROMPT = 'You are QuerierGPT. You will receive a protocol document detailing how to query a service. Reply with a structured query which can be sent to the service.' \
    'Only reply with the query itself, with no additional information or escaping. Similarly, do not add any additional whitespace or formatting.'

def construct_query_description(protocol_document, task_schema, task_data):
    query_description = ''
    if protocol_document is not None:
        query_description += 'Protocol document:\n\n'
        query_description += protocol_document + '\n\n'
    query_description += 'JSON schema of the task:\n\n'
    query_description += 'Input (i.e. what the machine will provide you):\n'
    query_description += json.dumps(task_schema['input'], indent=2) + '\n\n'
    query_description += 'Output (i.e. what you have to provide to the machine):\n'
    query_description += json.dumps(task_schema['output'], indent=2) + '\n\n'
    query_description += 'JSON data of the task:\n\n'
    query_description += json.dumps(task_data, indent=2) + '\n\n'

    return query_description

NL_QUERIER_PROMPT = 'You are NaturalLanguageQuerierGPT. You act as an intermediary between a machine (who has a very specific input and output schema) and an agent (who uses natural language).' \
    'You will receive a task description (including a schema of the input and output) that the machine uses and the corresponding data. Call the \"sendQuery\" tool with a natural language message where you ask to perform the task according to the data.' \
    'Make sure to mention all the relevant information. ' \
    'Do not worry about managing communication, everything is already set up for you. Just focus on asking the right question.' \
    'The sendQuery tool will return the reply of the service.\n' \
    'Once you receive the reply, call the \"deliverStructuredOutput\" tool with parameters according to the task\'s output schema. \n' \
    'Note: you cannot call sendQuery multiple times, so make sure to ask the right question the first time. Similarly, you cannot call deliverStructuredOutput multiple times, so make sure to deliver the right output the first time.' \
    'If the query fails, do not attempt to send another query.'

def parse_and_handle_query(query, callback_fn, protocol_id, source):
    if isinstance(query, dict):
        query = json.dumps(query)
    response = callback_fn({
        "protocolHash": protocol_id,
        "body": query,
        "protocolSources": None if source is None else [source]
    })

    print('Response:', response, type(response))

    if response['status'] == 'success':
        return response['body']
    else:
        return 'Error calling the tool: ' + response['message']

def get_output_parameters(task_schema):
    output_schema = task_schema['output']
    required_parameters = output_schema['required']

    parameters = []

    for parameter_name, parameter_schema in output_schema['properties'].items():
        parameter = parameter_from_openai_api(parameter_name, parameter_schema, parameter_name in required_parameters)
        parameters.append(parameter)

    return parameters

class Querier:
    def __init__(self, toolformer):
        self.toolformer = toolformer

    def handle_conversation(self, prompt, message, callback_fn, protocol_id, source, output_parameters):
        sent_query_counter = 0

        def send_query_internal(query):
            print('Sending query:', query)
            nonlocal sent_query_counter
            sent_query_counter += 1

            if sent_query_counter > 50:
                # All hope is lost, crash
                sys.exit(-2)
            elif sent_query_counter > 10:
                # LLM is not listening, throw an exception
                raise Exception('Too many attempts to send queries. Exiting.')
            elif sent_query_counter > 5:
                # LLM is not listening, issue a warning
                return 'You have attempted to send too many queries. Finish the message and allow the user to speak, or the system will crash.'
            elif sent_query_counter > 1:
                return 'You have already sent a query. You cannot send another one.'
            return parse_and_handle_query(query, callback_fn, protocol_id, source)

        send_query_tool = Tool('sendQuery', 'Send a query to the other service based on a protocol document.', [
            StringParameter('query', 'The query to send to the service', True)
        ], send_query_internal)

        found_output = None
        registered_output_counter = 0

        def register_output(**kwargs):
            print('Registering output:', kwargs)

            nonlocal found_output
            nonlocal registered_output_counter
            if found_output is not None:
                registered_output_counter += 1

            if registered_output_counter > 50:
                # All hope is lost, crash
                sys.exit(-2)
            elif registered_output_counter > 10:
                # LLM is not listening, raise an exception
                raise Exception('Too many attempts to register outputs. Exiting.')
            elif registered_output_counter > 5:
                # LLM is not listening, issue a warning
                return 'You have attempted to register too many outputs. Finish the message and allow the user to speak, or the system will crash.'
            elif registered_output_counter > 0:
                return 'You have already registered an output. You cannot register another one.'

            output = json.dumps(kwargs)

            found_output = output
            return 'Done'

        register_output_tool = Tool('deliverStructuredOutput', 'Deliver the structured output to the machine.', output_parameters, register_output)

        conversation = self.toolformer.new_conversation(prompt, [send_query_tool, register_output_tool], category='conversation')

        for i in range(5):
            conversation.chat(message, print_output=True)

            if found_output is not None:
                break

            # If we haven't sent a query yet, we can't proceed
            if sent_query_counter == 0:
                message = 'You must send a query before delivering the structured output.'
            elif found_output is None:
                message = 'You must deliver the structured output.'

        return found_output

    #def send_query_with_protocol(self, storage, task_schema, task_data, target_node, protocol_id, source):
    #    base_folder = Path(os.environ.get('STORAGE_PATH')) / 'protocol_documents'
    #    protocol_document = storage.load_protocol_document(base_folder, protocol_id)
    #    query_description = construct_query_description(protocol_document, task_schema, task_data)
    #    output_parameters = get_output_parameters(task_schema)
    #
    #    return self.handle_conversation(PROTOCOL_QUERIER_PROMPT, query_description, target_node, protocol_id, source, output_parameters)

    def send_query_without_protocol(self, task_schema, task_data, callback_fn):
        query_description = construct_query_description(None, task_schema, task_data)
        output_parameters = get_output_parameters(task_schema)

        return self.handle_conversation(NL_QUERIER_PROMPT, query_description, callback_fn, None, None, output_parameters)
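
Querier.send_query_without_protocol drives the whole natural-language exchange through callback_fn, which parse_and_handle_query above expects to return a dict with a 'status' key and either 'body' (on success) or 'message' (on error). A small sketch of a conforming callback, illustration only; the commented-out usage assumes a configured toolformer and the weather-forecast schema from app.py:

def echo_callback(query):
    # query is a dict with 'protocolHash', 'body' and 'protocolSources' keys.
    print('Received query:', query['body'])
    # A real callback would forward the body to the responder; here we fake a reply.
    return {'status': 'success', 'body': 'Temperature: 18C, sunny, humidity 55%, wind 10 km/h'}

# querier = Querier(create_toolformer('gpt-4o'))  # hypothetical; requires API credentials
# output = querier.send_query_without_protocol(schema, {'location': 'London', 'date': '2023-11-01'}, echo_callback)
# print(output)  # JSON string matching the task's output schema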
requirements.txt
ADDED
@@ -0,0 +1,28 @@
asknews==0.7.51
camel-ai==0.2.6
google-ai-generativelanguage==0.6.4
google-api-core==2.22.0
google-api-python-client==2.151.0
google-auth==2.36.0
google-auth-httplib2==0.2.0
google-cloud-core==2.4.1
google-cloud-storage==2.18.2
google-cloud-vision==3.8.0
google-crc32c==1.6.0
google-generativeai==0.6.0
gradio==5.5.0
gradio_client==1.4.2
huggingface-hub==0.26.2
langchain==0.3.7
langchain-anthropic==0.2.4
langchain-community==0.3.5
langchain-core==0.3.15
langchain-text-splitters==0.3.2
langgraph==0.2.45
langgraph-checkpoint==2.0.2
langgraph-sdk==0.1.35
python-dotenv==1.0.1
requests-oauthlib==1.3.1
sseclient-py==1.8.0
torch==2.5.1
transformers==4.46.2
responder.py
ADDED
@@ -0,0 +1,87 @@
# The responder is a special toolformer that replies to a service based on a protocol document.
# It receives the protocol document and writes the response that must be sent to the system.

import json
import os
from pathlib import Path

from toolformers.base import Toolformer

# TODO: A tool to declare an error?


PROTOCOL_RESPONDER_PROMPT = 'You are ResponderGPT. You will receive a protocol document detailing how to respond to a query. '\
    'Use the provided functions to execute what is requested and provide the response according to the protocol\'s specification. ' \
    'Only reply with the response itself, with no additional information or escaping. Similarly, do not add any additional whitespace or formatting.'# \
    # 'If you do not have enough information to reply, or if you cannot execute the request, reply with "ERROR" (without quotes).'

NL_RESPONDER_PROMPT = 'You are NaturalLanguageResponderGPT. You will receive a query from a user. ' \
    'Use the provided functions to execute what is requested and reply with a response (in natural language). ' \
    'Important: the user does not have the capacity to respond to follow-up questions, so if you think you have enough information to reply/execute the actions, do so.'
    #'If you do not have enough information to reply, if you cannot execute the request, or if the request is invalid, reply with "ERROR" (without quotes).' \



class Responder:
    def __init__(self, toolformer : Toolformer):
        self.toolformer = toolformer

    def reply_with_protocol_document(self, query, protocol_document, tools, additional_info):
        print('===NL RESPONDER (WITH PROTOCOL)===')

        conversation = self.toolformer.new_conversation(PROTOCOL_RESPONDER_PROMPT + additional_info, tools, category='conversation')

        prompt = 'The protocol is the following:\n\n' + protocol_document + '\n\nThe query is the following:' + query

        reply = conversation.chat(prompt, print_output=True)

        print('======')

        if 'error' in reply.lower().strip()[-10:]:
            return {
                'status': 'error',
                'message': 'Error in the response'
            }

        return {
            'status': 'success',
            'body': reply
        }


    def reply_to_nl_query(self, query, tools, additional_info):
        print('===NL RESPONDER (NO PROTOCOL)===')
        print(NL_RESPONDER_PROMPT + additional_info)
        print([tool.name for tool in tools])

        conversation = self.toolformer.new_conversation(NL_RESPONDER_PROMPT + additional_info, tools, category='conversation')

        print('Created conversation')
        try:
            reply = conversation.chat(query, print_output=True)
        except Exception as e:
            # Print traceback
            import traceback
            traceback.print_exc()
            raise e
        print('======')

        if 'error' in reply.lower().strip()[-10:]:
            return {
                'status': 'error',
            }

        return {
            'status': 'success',
            'body': reply
        }


    def reply_to_query(self, query, protocol_id, tools, additional_info):
        print('Additional info:', additional_info)
        if protocol_id is None:
            return self.reply_to_nl_query(query, tools, additional_info)
        #else:
        #    base_folder = Path(os.environ.get('STORAGE_PATH')) / 'protocol_documents'
        #    protocol_document = self.memory.load_protocol_document(base_folder, protocol_id)
        #    return self.reply_with_protocol_document(query, protocol_document, tools, additional_info)
toolformers/base.py
ADDED
@@ -0,0 +1,332 @@
from abc import ABC, abstractmethod

import json

from google.generativeai.types import CallableFunctionDeclaration
import google.generativeai.types.content_types as content_types

from utils import add_params_and_annotations

class Parameter:
    def __init__(self, name, description, required):
        self.name = name
        self.description = description
        self.required = required

    def as_openai_info(self):
        pass

    def as_standard_api(self):
        pass

class StringParameter(Parameter):
    def __init__(self, name, description, required):
        super().__init__(name, description, required)

    def as_openai_info(self):
        return {
            "type": "string",
            "name": self.name,
            "description": self.description
        }

    def as_standard_api(self):
        return {
            "type": "string",
            "name": self.name,
            "description": self.description,
            "required": self.required
        }

    def as_natural_language(self):
        return f'{self.name} (string{", required" if self.required else ""}): {self.description}.'

    def as_documented_python(self):
        return f'{self.name} (str{", required" if self.required else ""}): {self.description}.'

    def as_gemini_tool(self):
        return {
            'type': 'string',
            'description': self.description
        }

    @staticmethod
    def from_standard_api(api_info):
        return StringParameter(api_info["name"], api_info["description"], api_info["required"])

class EnumParameter(Parameter):
    def __init__(self, name, description, values, required):
        super().__init__(name, description, required)
        self.values = values

    def as_openai_info(self):
        return {
            "type": "string",
            "description": self.description,
            "values": self.values
        }

    def as_standard_api(self):
        return {
            "type": "enum",
            "name": self.name,
            "description": self.description,
            "values": self.values,
            "required": self.required
        }

    def as_natural_language(self):
        return f'{self.name} (enum{", required" if self.required else ""}): {self.description}. Possible values: {", ".join(self.values)}'

    def as_documented_python(self):
        return f'{self.name} (str{", required" if self.required else ""}): {self.description}. Possible values: {", ".join(self.values)}'

    def as_gemini_tool(self):
        return {
            'description': self.description,
            'type': 'string',
            'enum': self.values
        }

    @staticmethod
    def from_standard_api(api_info):
        return EnumParameter(api_info["name"], api_info["description"], api_info["values"], api_info["required"])

class NumberParameter(Parameter):
    def __init__(self, name, description, required):
        super().__init__(name, description, required)

    def as_openai_info(self):
        return {
            "type": "number",
            "description": self.description
        }

    def as_standard_api(self):
        return {
            "type": "number",
            "name": self.name,
            "description": self.description,
            "required": self.required
        }

    def as_natural_language(self):
        return f'{self.name} (number): {self.description}'

    def as_documented_python(self):
        return f'{self.name} (number): {self.description}'

    def as_gemini_tool(self):
        return {
            'description': self.description,
            'type': 'number'
        }

class ArrayParameter(Parameter):
    def __init__(self, name, description, required, item_schema):
        super().__init__(name, description, required)
        self.item_schema = item_schema

    def as_openai_info(self):
        return {
            "type": "array",
            "description": self.description,
            "items": self.item_schema
        }

    def as_standard_api(self):
        return {
            "type": "array",
            "name": self.name,
            "description": self.description,
            "required": self.required,
            "item_schema": self.item_schema
        }

    def as_natural_language(self):
        return f'{self.name} (array): {self.description}. Each item should follow the JSON schema: {json.dumps(self.item_schema)}'

    def as_documented_python(self):
        return f'{self.name} (list): {self.description}. Each item should follow the JSON schema: {json.dumps(self.item_schema)}'

    def as_gemini_tool(self):
        return {
            'description': self.description,
            'type': 'array',
            'items': self.item_schema
        }

def parameter_from_openai_api(parameter_name, schema, required):
    if 'enum' in schema:
        return EnumParameter(parameter_name, schema['description'], schema['enum'], required)
    elif schema['type'] == 'string':
        return StringParameter(parameter_name, schema['description'], required)
    elif schema['type'] == 'number':
        return NumberParameter(parameter_name, schema['description'], required)
    elif schema['type'] == 'array':
        return ArrayParameter(parameter_name, schema['description'], required, schema['items'])
    else:
        raise ValueError(f'Unknown parameter type: {schema["type"]}')

class Tool:
    def __init__(self, name, description, parameters, function, output_schema=None):
        self.name = name
        self.description = description
        self.parameters = parameters
        self.function = function
        self.output_schema = output_schema

    def call_tool_for_toolformer(self, *args, **kwargs):
        print(f'Toolformer called tool {self.name} with args {args} and kwargs {kwargs}')
        # Unlike a call from a routine, this call catches exceptions and returns them as strings
        try:
            tool_reply = self.function(*args, **kwargs)
            print(f'Tool {self.name} returned: {tool_reply}')
            return tool_reply
        except Exception as e:
            print(f'Tool {self.name} failed with exception: {e}')
            return 'Tool call failed: ' + str(e)

    def as_openai_info(self):
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type" : "object",
                    "properties": {parameter.name : parameter.as_openai_info() for parameter in self.parameters},
                    "required": [parameter.name for parameter in self.parameters if parameter.required]
                }
            }
        }

    def as_gemini_tool(self) -> CallableFunctionDeclaration:
        if len(self.parameters) == 0:
            parameters = None
        else:
            parameters = {
                'type': 'object',
                'properties': {parameter.name: parameter.as_gemini_tool() for parameter in self.parameters},
                'required': [parameter.name for parameter in self.parameters if parameter.required]
            }
        return content_types.Tool([CallableFunctionDeclaration(
            name=self.name,
            description=self.description,
            parameters=parameters,
            function=self.call_tool_for_toolformer
        )])

    def as_llama_schema(self):
        schema = {
            'name': self.name,
            'description': self.description,
            'parameters': {parameter.name : parameter.as_openai_info() for parameter in self.parameters},
            'required': [parameter.name for parameter in self.parameters if parameter.required]
        }

        if self.output_schema is not None:
            schema['output_schema'] = self.output_schema

        return schema

    def as_natural_language(self):
|
234 |
+
print('Converting to natural language')
|
235 |
+
print('Number of parameters:', len(self.parameters))
|
236 |
+
nl = f'Function {self.name}: {self.description}. Parameters:\n'
|
237 |
+
|
238 |
+
if len(self.parameters) == 0:
|
239 |
+
nl += 'No parameters.'
|
240 |
+
else:
|
241 |
+
for parameter in self.parameters:
|
242 |
+
nl += '\t' + parameter.as_natural_language() + '\n'
|
243 |
+
|
244 |
+
if self.output_schema is not None:
|
245 |
+
nl += f'\Returns a dictionary with schema: {json.dumps(self.output_schema, indent=2)}'
|
246 |
+
|
247 |
+
return nl
|
248 |
+
|
249 |
+
def as_standard_api(self):
|
250 |
+
return {
|
251 |
+
"name": self.name,
|
252 |
+
"description": self.description,
|
253 |
+
"parameters": [parameter.as_standard_api() for parameter in self.parameters]
|
254 |
+
}
|
255 |
+
|
256 |
+
def as_documented_python(self):
|
257 |
+
documented_python = f'Tool {self.name}:\n\n{self.description}\nParameters:\n'
|
258 |
+
|
259 |
+
if len(self.parameters) == 0:
|
260 |
+
documented_python += 'No parameters.'
|
261 |
+
else:
|
262 |
+
for parameter in self.parameters:
|
263 |
+
documented_python += '\t' + parameter.as_documented_python() + '\n'
|
264 |
+
|
265 |
+
if self.output_schema is not None:
|
266 |
+
documented_python += f'\Returns a dictionary with schema: {json.dumps(self.output_schema, indent=2)}'
|
267 |
+
|
268 |
+
return documented_python
|
269 |
+
|
270 |
+
def as_executable_function(self):
|
271 |
+
# Create an actual function that can be called
|
272 |
+
def f(*args, **kwargs):
|
273 |
+
print('Routine called tool', self.name, 'with args', args, 'and kwargs', kwargs)
|
274 |
+
response = self.function(*args, **kwargs)
|
275 |
+
print('Tool', self.name, 'returned:', response)
|
276 |
+
return response
|
277 |
+
|
278 |
+
return f
|
279 |
+
|
280 |
+
def as_annotated_function(self):
|
281 |
+
def wrapped_fn(*args, **kwargs):
|
282 |
+
return self.call_tool_for_toolformer(*args, **kwargs)
|
283 |
+
|
284 |
+
parsed_parameters = {}
|
285 |
+
|
286 |
+
description = self.description
|
287 |
+
|
288 |
+
for parameter_name, parameter_schema in self.as_openai_info()['function']['parameters']['properties'].items():
|
289 |
+
if parameter_schema['type'] == 'string':
|
290 |
+
parsed_parameters[parameter_name] = (str, parameter_schema['description'])
|
291 |
+
elif parameter_schema['type'] == 'number':
|
292 |
+
parsed_parameters[parameter_name] = (float, parameter_schema['description'])
|
293 |
+
elif parameter_schema['type'] == 'object':
|
294 |
+
parsed_parameters[parameter_name] = (dict, parameter_schema['description'])
|
295 |
+
|
296 |
+
description += f'\n{parameter_name} has the schema:\n' + json.dumps(parameter_schema) + '\n'
|
297 |
+
else:
|
298 |
+
raise ValueError(f'Unknown parameter type: {parameter_schema["type"]}')
|
299 |
+
|
300 |
+
return_type = type(None)
|
301 |
+
|
302 |
+
if self.output_schema is not None:
|
303 |
+
#description += '\nOutput schema:\n' + json.dumps(self.output_schema)
|
304 |
+
|
305 |
+
if self.output_schema['type'] == 'string':
|
306 |
+
return_type = str
|
307 |
+
elif self.output_schema['type'] == 'number':
|
308 |
+
return_type = float
|
309 |
+
elif self.output_schema['type'] == 'object':
|
310 |
+
return_type = dict
|
311 |
+
else:
|
312 |
+
raise ValueError(f'Unknown output type: {self.output_schema["type"]}')
|
313 |
+
|
314 |
+
return add_params_and_annotations(
|
315 |
+
self.name, description, parsed_parameters, return_type)(wrapped_fn)
|
316 |
+
|
317 |
+
@staticmethod
|
318 |
+
def from_openai_info(info, func):
|
319 |
+
parameters = [parameter_from_openai_api(name, schema, name in info['function']['parameters']['required']) for name, schema in info['function']['parameters']['properties'].items()]
|
320 |
+
return Tool(info['function']['name'], info['function']['description'], parameters, func)
|
321 |
+
|
322 |
+
|
323 |
+
class Conversation(ABC):
|
324 |
+
@abstractmethod
|
325 |
+
def chat(self, message, role='user', print_output=True):
|
326 |
+
pass
|
327 |
+
|
328 |
+
class Toolformer(ABC):
|
329 |
+
@abstractmethod
|
330 |
+
def new_conversation(self, prompt, tools, category=None) -> Conversation:
|
331 |
+
pass
|
332 |
+
|
toolformers/camel.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import os
|
3 |
+
from typing import List
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
from toolformers.base import Conversation, Toolformer, Tool
|
7 |
+
from camel.messages import BaseMessage
|
8 |
+
from camel.models import ModelFactory
|
9 |
+
from camel.types import ModelPlatformType, ModelType
|
10 |
+
from camel.messages import BaseMessage as bm
|
11 |
+
from camel.agents import ChatAgent
|
12 |
+
from camel.toolkits.function_tool import FunctionTool
|
13 |
+
from camel.configs.openai_config import ChatGPTConfig
|
14 |
+
|
15 |
+
class CamelConversation(Conversation):
|
16 |
+
def __init__(self, toolformer, agent, category=None):
|
17 |
+
self.toolformer = toolformer
|
18 |
+
self.agent = agent
|
19 |
+
self.category = category
|
20 |
+
|
21 |
+
def chat(self, message, role='user', print_output=True):
|
22 |
+
agent_id = os.environ.get('AGENT_ID', None)
|
23 |
+
|
24 |
+
start_time = datetime.datetime.now()
|
25 |
+
|
26 |
+
if role == 'user':
|
27 |
+
formatted_message = BaseMessage.make_user_message('user', message)
|
28 |
+
elif role == 'assistant':
|
29 |
+
formatted_message = BaseMessage.make_assistant_message('assistant', message)
|
30 |
+
else:
|
31 |
+
raise ValueError('Role must be either "user" or "assistant".')
|
32 |
+
|
33 |
+
response = self.agent.step(formatted_message)
|
34 |
+
|
35 |
+
reply = response.msg.content
|
36 |
+
|
37 |
+
if print_output:
|
38 |
+
print(reply)
|
39 |
+
|
40 |
+
return reply
|
41 |
+
|
42 |
+
class CamelToolformer(Toolformer):
|
43 |
+
def __init__(self, model_platform, model_type, model_config_dict, name=None):
|
44 |
+
self.model_platform = model_platform
|
45 |
+
self.model_type = model_type
|
46 |
+
self.model_config_dict = model_config_dict
|
47 |
+
self._name = name
|
48 |
+
|
49 |
+
@property
|
50 |
+
def name(self):
|
51 |
+
if self._name is None:
|
52 |
+
return f'{self.model_platform.value}_{self.model_type.value}'
|
53 |
+
else:
|
54 |
+
return self._name
|
55 |
+
|
56 |
+
def new_conversation(self, prompt, tools : List[Tool], category=None) -> Conversation:
|
57 |
+
model = ModelFactory.create(
|
58 |
+
model_platform=self.model_platform,
|
59 |
+
model_type=self.model_type,
|
60 |
+
model_config_dict=self.model_config_dict
|
61 |
+
)
|
62 |
+
|
63 |
+
agent = ChatAgent(
|
64 |
+
model=model,
|
65 |
+
system_message=bm.make_assistant_message('system', prompt),
|
66 |
+
tools=[FunctionTool(tool.call_tool_for_toolformer, openai_tool_schema=tool.as_openai_info()) for tool in tools]
|
67 |
+
)
|
68 |
+
|
69 |
+
return CamelConversation(self, agent, category)
|
70 |
+
|
71 |
+
def make_openai_toolformer(model_type_internal):
|
72 |
+
if model_type_internal == 'gpt-4o':
|
73 |
+
model_type = ModelType.GPT_4O
|
74 |
+
elif model_type_internal == 'gpt-4o-mini':
|
75 |
+
model_type = ModelType.GPT_4O_MINI
|
76 |
+
else:
|
77 |
+
raise ValueError('Model type must be either "gpt-4o" or "gpt-4o-mini".')
|
78 |
+
|
79 |
+
#formatted_tools = [FunctionTool(tool.call_tool_for_toolformer, tool.as_openai_info()) for tool in tools]
|
80 |
+
|
81 |
+
return CamelToolformer(
|
82 |
+
model_platform=ModelPlatformType.OPENAI,
|
83 |
+
model_type=model_type,
|
84 |
+
model_config_dict=ChatGPTConfig(temperature=0.2).as_dict(),
|
85 |
+
name=model_type_internal
|
86 |
+
)
|
toolformers/gemini.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import os
|
3 |
+
from random import random
|
4 |
+
import time
|
5 |
+
import traceback
|
6 |
+
from typing import List
|
7 |
+
|
8 |
+
from toolformers.base import Conversation, Tool, Toolformer
|
9 |
+
|
10 |
+
import google.generativeai as genai
|
11 |
+
from google.generativeai.generative_models import ChatSession
|
12 |
+
|
13 |
+
genai.configure(api_key=os.environ['GOOGLE_API_KEY'])
|
14 |
+
|
15 |
+
class GeminiConversation(Conversation):
|
16 |
+
def __init__(self, model_name, chat_agent : ChatSession, category=None):
|
17 |
+
self.model_name = model_name
|
18 |
+
self.chat_agent = chat_agent
|
19 |
+
self.category = category
|
20 |
+
|
21 |
+
def chat(self, message, role='user', print_output=True):
|
22 |
+
agent_id = os.environ.get('AGENT_ID', None)
|
23 |
+
time_start = datetime.datetime.now()
|
24 |
+
|
25 |
+
exponential_backoff_lower = 30
|
26 |
+
exponential_backoff_higher = 60
|
27 |
+
for i in range(5):
|
28 |
+
try:
|
29 |
+
response = self.chat_agent.send_message({
|
30 |
+
'role': role,
|
31 |
+
'parts': [
|
32 |
+
message
|
33 |
+
]
|
34 |
+
})
|
35 |
+
break
|
36 |
+
except Exception as e:
|
37 |
+
print(e)
|
38 |
+
if '429' in str(e):
|
39 |
+
print('Rate limit exceeded. Waiting with random exponential backoff.')
|
40 |
+
if i < 4:
|
41 |
+
time.sleep(random() * (exponential_backoff_higher - exponential_backoff_lower) + exponential_backoff_lower)
|
42 |
+
exponential_backoff_lower *= 2
|
43 |
+
exponential_backoff_higher *= 2
|
44 |
+
elif 'candidates[0]' in traceback.format_exc():
|
45 |
+
# When Gemini has nothing to say, it raises an error with this message
|
46 |
+
print('No response')
|
47 |
+
return 'No response'
|
48 |
+
elif '500' in str(e):
|
49 |
+
# Sometimes Gemini just decides to return a 500 error for absolutely no reason. Retry.
|
50 |
+
print('500 error')
|
51 |
+
time.sleep(5)
|
52 |
+
traceback.print_exc()
|
53 |
+
else:
|
54 |
+
raise e
|
55 |
+
|
56 |
+
time_end = datetime.datetime.now()
|
57 |
+
|
58 |
+
usage_info = {
|
59 |
+
'prompt_tokens': response.usage_metadata.prompt_token_count,
|
60 |
+
'completion_tokens': response.usage_metadata.candidates_token_count
|
61 |
+
}
|
62 |
+
|
63 |
+
#send_usage_to_db(
|
64 |
+
# usage_info,
|
65 |
+
# time_start,
|
66 |
+
# time_end,
|
67 |
+
# agent_id,
|
68 |
+
# self.category,
|
69 |
+
# self.model_name
|
70 |
+
#)
|
71 |
+
|
72 |
+
reply = response.text
|
73 |
+
|
74 |
+
if print_output:
|
75 |
+
print(reply)
|
76 |
+
|
77 |
+
return reply
|
78 |
+
|
79 |
+
class GeminiToolformer(Toolformer):
|
80 |
+
def __init__(self, model_name):
|
81 |
+
self.model_name = model_name
|
82 |
+
|
83 |
+
def new_conversation(self, system_prompt, tools : List[Tool], category=None) -> Conversation:
|
84 |
+
print('Tools:')
|
85 |
+
print('\n'.join([str(tool.as_openai_info()) for tool in tools]))
|
86 |
+
model = genai.GenerativeModel(
|
87 |
+
model_name=self.model_name,
|
88 |
+
system_instruction=system_prompt,
|
89 |
+
tools=[tool.as_gemini_tool() for tool in tools]
|
90 |
+
)
|
91 |
+
|
92 |
+
chat = model.start_chat(enable_automatic_function_calling=True)
|
93 |
+
|
94 |
+
return GeminiConversation(self.model_name, chat, category)
|
95 |
+
|
96 |
+
def make_gemini_toolformer(model_name):
|
97 |
+
if model_name not in ['gemini-1.5-flash', 'gemini-1.5-pro']:
|
98 |
+
raise ValueError(f"Unknown model name: {model_name}")
|
99 |
+
|
100 |
+
return GeminiToolformer(model_name)
|
toolformers/huggingface_agent.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import InferenceClient
|
2 |
+
from transformers import ReactCodeAgent
|
3 |
+
from transformers import tool as function_to_tool
|
4 |
+
|
5 |
+
from toolformers.base import Conversation, Toolformer
|
6 |
+
|
7 |
+
class HuggingFaceConversation(Conversation):
|
8 |
+
def __init__(self, agent : ReactCodeAgent, prompt, category=None):
|
9 |
+
self.agent = agent
|
10 |
+
self.messages = [('system', prompt)]
|
11 |
+
self.category = category
|
12 |
+
|
13 |
+
def chat(self, message, role='user', print_output=True) -> str:
|
14 |
+
self.messages.append((role, message))
|
15 |
+
|
16 |
+
final_prompt = 'For context, here are the previous messages in the conversation:\n\n'
|
17 |
+
|
18 |
+
for role, message in self.messages:
|
19 |
+
final_prompt += f'{role.capitalize()}: {message}\n'
|
20 |
+
|
21 |
+
final_prompt += "Don't worry, you don't need to use the same format to reply. Stick with the Task:/Action:/etc. format.\n\n"
|
22 |
+
|
23 |
+
response = self.agent.run(final_prompt)
|
24 |
+
print(response)
|
25 |
+
return response
|
26 |
+
|
27 |
+
class HuggingFaceToolformer(Toolformer):
|
28 |
+
def __init__(self, model_name, max_tokens=2000):
|
29 |
+
self.model = InferenceClient(model=model_name)
|
30 |
+
self.max_tokens = max_tokens
|
31 |
+
|
32 |
+
def llm_engine(self, messages, stop_sequences=["Task"]) -> str:
|
33 |
+
response = self.model.chat_completion(messages, stop=stop_sequences, max_tokens=self.max_tokens)
|
34 |
+
answer = response.choices[0].message.content
|
35 |
+
return answer
|
36 |
+
|
37 |
+
def new_conversation(self, prompt, tools, category=None):
|
38 |
+
agent = ReactCodeAgent(tools=[function_to_tool(tool.as_annotated_function()) for tool in tools], llm_engine=self.llm_engine)
|
39 |
+
return HuggingFaceConversation(agent, prompt, category=category)
|
toolformers/langchain_agent.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
# Import relevant functionality
|
4 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
5 |
+
from langgraph.checkpoint.memory import MemorySaver
|
6 |
+
from langgraph.prebuilt import create_react_agent
|
7 |
+
from langchain_anthropic import ChatAnthropic
|
8 |
+
|
9 |
+
import sys
|
10 |
+
sys.path.append('.')
|
11 |
+
from toolformers.base import Tool as AgoraTool
|
12 |
+
|
13 |
+
|
14 |
+
from langchain_core.tools import tool as function_to_tool
|
15 |
+
|
16 |
+
from toolformers.base import StringParameter, Toolformer, Conversation
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
class LangChainConversation(Conversation):
|
21 |
+
def __init__(self, agent, messages, category=None):
|
22 |
+
self.agent = agent
|
23 |
+
self.messages = messages
|
24 |
+
self.category = category
|
25 |
+
|
26 |
+
def chat(self, message, role='user', print_output=True) -> str:
|
27 |
+
self.messages.append(HumanMessage(content=message))
|
28 |
+
final_message = ''
|
29 |
+
for chunk in self.agent.stream({"messages": self.messages}, stream_mode="values"):
|
30 |
+
print(chunk)
|
31 |
+
print("----")
|
32 |
+
for message in chunk['messages']:
|
33 |
+
if isinstance(message, AIMessage):
|
34 |
+
content = message.content
|
35 |
+
if isinstance(content, str):
|
36 |
+
final_message += content
|
37 |
+
else:
|
38 |
+
for content_chunk in content:
|
39 |
+
if isinstance(content_chunk, str):
|
40 |
+
final_message += content_chunk
|
41 |
+
#final_message += chunk['agent']['messages'].content
|
42 |
+
|
43 |
+
self.messages.append(AIMessage(content=final_message))
|
44 |
+
#print(final_message)
|
45 |
+
|
46 |
+
return final_message
|
47 |
+
|
48 |
+
class LangChainAnthropicToolformer(Toolformer):
|
49 |
+
def __init__(self, model_name, api_key):
|
50 |
+
self.model_name = model_name
|
51 |
+
self.api_key = api_key
|
52 |
+
|
53 |
+
def new_conversation(self, prompt, tools, category=None):
|
54 |
+
tools = [function_to_tool(tool.as_annotated_function()) for tool in tools]
|
55 |
+
model = ChatAnthropic(model_name=self.model_name, api_key=self.api_key)
|
56 |
+
agent_executor = create_react_agent(model, tools)
|
57 |
+
|
58 |
+
return LangChainConversation(agent_executor, [SystemMessage(prompt)], category)
|
59 |
+
|
60 |
+
|
61 |
+
#weather_tool = AgoraTool("WeatherForecastAPI", "A simple tool that returns the weather", [StringParameter(
|
62 |
+
# name="location",
|
63 |
+
# description="The name of the location for which the weather forecast is requested.",
|
64 |
+
# required=True
|
65 |
+
#)], lambda location: 'Sunny', {
|
66 |
+
# "type": "string"
|
67 |
+
#})
|
68 |
+
#
|
69 |
+
#tools = [agora_tool_to_langchain(weather_tool)]
|
70 |
+
#toolformer = LangChainToolformer("claude-3-sonnet-20240229", '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')
|
71 |
+
#conversation = toolformer.new_conversation('You are a weather bot', [weather_tool])
|
72 |
+
#
|
73 |
+
#print(conversation.chat('What is the weather in San Francisco?'))
|
toolformers/sambanova/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .core import SambanovaToolformer
|
toolformers/sambanova/api_gateway.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
from langchain_community.llms.sambanova import SambaStudio
|
7 |
+
from langchain_core.language_models.llms import LLM
|
8 |
+
|
9 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
10 |
+
utils_dir = os.path.abspath(os.path.join(current_dir, '..'))
|
11 |
+
repo_dir = os.path.abspath(os.path.join(utils_dir, '..'))
|
12 |
+
sys.path.append(utils_dir)
|
13 |
+
sys.path.append(repo_dir)
|
14 |
+
|
15 |
+
from toolformers.sambanova.sambanova_langchain import SambaNovaCloud
|
16 |
+
|
17 |
+
EMBEDDING_MODEL = 'intfloat/e5-large-v2'
|
18 |
+
NORMALIZE_EMBEDDINGS = True
|
19 |
+
|
20 |
+
# Configure the logger
|
21 |
+
logging.basicConfig(
|
22 |
+
level=logging.INFO,
|
23 |
+
format='%(asctime)s [%(levelname)s] - %(message)s',
|
24 |
+
handlers=[
|
25 |
+
logging.StreamHandler(),
|
26 |
+
],
|
27 |
+
)
|
28 |
+
logger = logging.getLogger(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
class APIGateway:
|
32 |
+
@staticmethod
|
33 |
+
def load_llm(
|
34 |
+
type: str,
|
35 |
+
streaming: bool = False,
|
36 |
+
coe: bool = False,
|
37 |
+
do_sample: Optional[bool] = None,
|
38 |
+
max_tokens_to_generate: Optional[int] = None,
|
39 |
+
temperature: Optional[float] = None,
|
40 |
+
select_expert: Optional[str] = None,
|
41 |
+
top_p: Optional[float] = None,
|
42 |
+
top_k: Optional[int] = None,
|
43 |
+
repetition_penalty: Optional[float] = None,
|
44 |
+
stop_sequences: Optional[str] = None,
|
45 |
+
process_prompt: Optional[bool] = False,
|
46 |
+
sambastudio_base_url: Optional[str] = None,
|
47 |
+
sambastudio_base_uri: Optional[str] = None,
|
48 |
+
sambastudio_project_id: Optional[str] = None,
|
49 |
+
sambastudio_endpoint_id: Optional[str] = None,
|
50 |
+
sambastudio_api_key: Optional[str] = None,
|
51 |
+
sambanova_url: Optional[str] = None,
|
52 |
+
sambanova_api_key: Optional[str] = None,
|
53 |
+
) -> LLM:
|
54 |
+
"""Loads a langchain Sambanova llm model given a type and parameters
|
55 |
+
Args:
|
56 |
+
type (str): wether to use sambastudio, or SambaNova Cloud model "sncloud"
|
57 |
+
streaming (bool): wether to use streaming method. Defaults to False.
|
58 |
+
coe (bool): whether to use coe model. Defaults to False.
|
59 |
+
|
60 |
+
do_sample (bool) : Optional wether to do sample.
|
61 |
+
max_tokens_to_generate (int) : Optional max number of tokens to generate.
|
62 |
+
temperature (float) : Optional model temperature.
|
63 |
+
select_expert (str) : Optional expert to use when using CoE models.
|
64 |
+
top_p (float) : Optional model top_p.
|
65 |
+
top_k (int) : Optional model top_k.
|
66 |
+
repetition_penalty (float) : Optional model repetition penalty.
|
67 |
+
stop_sequences (str) : Optional model stop sequences.
|
68 |
+
process_prompt (bool) : Optional default to false.
|
69 |
+
|
70 |
+
sambastudio_base_url (str): Optional SambaStudio environment URL".
|
71 |
+
sambastudio_base_uri (str): Optional SambaStudio-base-URI".
|
72 |
+
sambastudio_project_id (str): Optional SambaStudio project ID.
|
73 |
+
sambastudio_endpoint_id (str): Optional SambaStudio endpoint ID.
|
74 |
+
sambastudio_api_token (str): Optional SambaStudio endpoint API key.
|
75 |
+
|
76 |
+
sambanova_url (str): Optional SambaNova Cloud URL",
|
77 |
+
sambanova_api_key (str): Optional SambaNovaCloud API key.
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
langchain llm model
|
81 |
+
"""
|
82 |
+
|
83 |
+
if type == 'sambastudio':
|
84 |
+
envs = {
|
85 |
+
'sambastudio_base_url': sambastudio_base_url,
|
86 |
+
'sambastudio_base_uri': sambastudio_base_uri,
|
87 |
+
'sambastudio_project_id': sambastudio_project_id,
|
88 |
+
'sambastudio_endpoint_id': sambastudio_endpoint_id,
|
89 |
+
'sambastudio_api_key': sambastudio_api_key,
|
90 |
+
}
|
91 |
+
envs = {k: v for k, v in envs.items() if v is not None}
|
92 |
+
if coe:
|
93 |
+
model_kwargs = {
|
94 |
+
'do_sample': do_sample,
|
95 |
+
'max_tokens_to_generate': max_tokens_to_generate,
|
96 |
+
'temperature': temperature,
|
97 |
+
'select_expert': select_expert,
|
98 |
+
'top_p': top_p,
|
99 |
+
'top_k': top_k,
|
100 |
+
'repetition_penalty': repetition_penalty,
|
101 |
+
'stop_sequences': stop_sequences,
|
102 |
+
'process_prompt': process_prompt,
|
103 |
+
}
|
104 |
+
model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}
|
105 |
+
|
106 |
+
llm = SambaStudio(
|
107 |
+
**envs,
|
108 |
+
streaming=streaming,
|
109 |
+
model_kwargs=model_kwargs,
|
110 |
+
)
|
111 |
+
else:
|
112 |
+
model_kwargs = {
|
113 |
+
'do_sample': do_sample,
|
114 |
+
'max_tokens_to_generate': max_tokens_to_generate,
|
115 |
+
'temperature': temperature,
|
116 |
+
'top_p': top_p,
|
117 |
+
'top_k': top_k,
|
118 |
+
'repetition_penalty': repetition_penalty,
|
119 |
+
'stop_sequences': stop_sequences,
|
120 |
+
}
|
121 |
+
model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}
|
122 |
+
llm = SambaStudio(
|
123 |
+
**envs,
|
124 |
+
streaming=streaming,
|
125 |
+
model_kwargs=model_kwargs,
|
126 |
+
)
|
127 |
+
|
128 |
+
elif type == 'sncloud':
|
129 |
+
envs = {
|
130 |
+
'sambanova_url': sambanova_url,
|
131 |
+
'sambanova_api_key': sambanova_api_key,
|
132 |
+
}
|
133 |
+
envs = {k: v for k, v in envs.items() if v is not None}
|
134 |
+
llm = SambaNovaCloud(
|
135 |
+
**envs,
|
136 |
+
max_tokens=max_tokens_to_generate,
|
137 |
+
model=select_expert,
|
138 |
+
temperature=temperature,
|
139 |
+
top_k=top_k,
|
140 |
+
top_p=top_p,
|
141 |
+
)
|
142 |
+
|
143 |
+
else:
|
144 |
+
raise ValueError(f"Invalid LLM API: {type}, only 'sncloud' and 'sambastudio' are supported.")
|
145 |
+
|
146 |
+
return llm
|
toolformers/sambanova/core.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import os
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
from toolformers.base import Conversation, Toolformer, Tool
|
6 |
+
from toolformers.sambanova.function_calling import FunctionCallingLlm
|
7 |
+
|
8 |
+
class SambanovaConversation(Conversation):
|
9 |
+
def __init__(self, model_name, function_calling_llm : FunctionCallingLlm, category=None):
|
10 |
+
self.model_name = model_name
|
11 |
+
self.function_calling_llm = function_calling_llm
|
12 |
+
self.category = category
|
13 |
+
|
14 |
+
def chat(self, message, role='user', print_output=True):
|
15 |
+
if role != 'user':
|
16 |
+
raise ValueError('Role must be "user"')
|
17 |
+
|
18 |
+
agent_id = os.environ.get('AGENT_ID', None)
|
19 |
+
|
20 |
+
start_time = datetime.datetime.now()
|
21 |
+
|
22 |
+
response, usage_data = self.function_calling_llm.function_call_llm(message)
|
23 |
+
|
24 |
+
end_time = datetime.datetime.now()
|
25 |
+
|
26 |
+
print('Usage data:', usage_data)
|
27 |
+
if print_output:
|
28 |
+
print(response)
|
29 |
+
|
30 |
+
#send_usage_to_db(usage_data, start_time, end_time, agent_id, self.category, self.model_name)
|
31 |
+
|
32 |
+
return response
|
33 |
+
|
34 |
+
class SambanovaToolformer(Toolformer):
|
35 |
+
def __init__(self, model_name: str):
|
36 |
+
self.model_name = model_name
|
37 |
+
|
38 |
+
def new_conversation(self, system_prompt: str, tools: List[Tool], category=None) -> SambanovaConversation:
|
39 |
+
function_calling_llm = FunctionCallingLlm(system_prompt=system_prompt, tools=tools, select_expert=self.model_name)
|
40 |
+
return SambanovaConversation(self.model_name, function_calling_llm, category)
|
41 |
+
|
42 |
+
#def make_llama_toolformer(model_name, system_prompt: str, tools: List[Tool]):
|
43 |
+
# if model_name not in ['llama3-8b', 'llama3-70b', 'llama3-405b']:
|
44 |
+
# raise ValueError(f"Unknown model name: {model_name}")
|
45 |
+
#
|
46 |
+
# return SambanovaToolformer(model_name, system_prompt, tools)
|
toolformers/sambanova/function_calling.py
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
from random import random
|
5 |
+
from pprint import pprint
|
6 |
+
import time
|
7 |
+
from typing import List, Optional, Union
|
8 |
+
|
9 |
+
from langchain_core.messages.ai import AIMessage
|
10 |
+
from langchain_core.messages.human import HumanMessage
|
11 |
+
from langchain_core.messages.tool import ToolMessage
|
12 |
+
from langchain_core.prompts import ChatPromptTemplate
|
13 |
+
from langchain_core.runnables import RunnableLambda
|
14 |
+
|
15 |
+
from toolformers.base import Tool, StringParameter
|
16 |
+
from toolformers.sambanova.api_gateway import APIGateway
|
17 |
+
|
18 |
+
from toolformers.sambanova.utils import get_total_usage, usage_tracker
|
19 |
+
|
20 |
+
|
21 |
+
FUNCTION_CALLING_SYSTEM_PROMPT = """You have access to the following tools:
|
22 |
+
|
23 |
+
{tools}
|
24 |
+
|
25 |
+
You can call one or more tools by adding a <ToolCalls> section to your message. For example:
|
26 |
+
<ToolCalls>
|
27 |
+
```json
|
28 |
+
[{{
|
29 |
+
"tool": <name of the selected tool>,
|
30 |
+
"tool_input": <parameters for the selected tool, matching the tool's JSON schema>
|
31 |
+
}}]
|
32 |
+
```
|
33 |
+
</ToolCalls>
|
34 |
+
|
35 |
+
Note that you can select multiple tools at once by adding more objects to the list. Do not add \
|
36 |
+
multiple <ToolCalls> sections to the same message.
|
37 |
+
You will see the invocation of the tools in the response.
|
38 |
+
|
39 |
+
|
40 |
+
Think step by step
|
41 |
+
Do not call a tool if the input depends on another tool output that you do not have yet.
|
42 |
+
Do not try to answer until you get all the tools output, if you do not have an answer yet, you can continue calling tools until you do.
|
43 |
+
Your answer should be in the same language as the initial query.
|
44 |
+
|
45 |
+
""" # noqa E501
|
46 |
+
|
47 |
+
|
48 |
+
conversational_response = Tool(
|
49 |
+
name='ConversationalResponse',
|
50 |
+
description='Respond conversationally only if no other tools should be called for a given query, or if you have a final answer. Response must be in the same language as the user query.',
|
51 |
+
parameters=[StringParameter(name='response', description='Conversational response to the user. Must be in the same language as the user query.', required=True)],
|
52 |
+
function=None
|
53 |
+
)
|
54 |
+
|
55 |
+
|
56 |
+
class FunctionCallingLlm:
|
57 |
+
"""
|
58 |
+
function calling llm class
|
59 |
+
"""
|
60 |
+
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
tools: Optional[Union[Tool, List[Tool]]] = None,
|
64 |
+
default_tool: Optional[Tool] = None,
|
65 |
+
system_prompt: Optional[str] = None,
|
66 |
+
prod_mode: bool = False,
|
67 |
+
api: str = 'sncloud',
|
68 |
+
coe: bool = False,
|
69 |
+
do_sample: bool = False,
|
70 |
+
max_tokens_to_generate: Optional[int] = None,
|
71 |
+
temperature: float = 0.2,
|
72 |
+
select_expert: Optional[str] = None,
|
73 |
+
) -> None:
|
74 |
+
"""
|
75 |
+
Args:
|
76 |
+
tools (Optional[Union[Tool, List[Tool]]]): The tools to use.
|
77 |
+
default_tool (Optional[Tool]): The default tool to use.
|
78 |
+
defaults to ConversationalResponse
|
79 |
+
system_prompt (Optional[str]): The system prompt to use. defaults to FUNCTION_CALLING_SYSTEM_PROMPT
|
80 |
+
prod_mode (bool): Whether to use production mode. Defaults to False.
|
81 |
+
api (str): The api to use. Defaults to 'sncloud'.
|
82 |
+
coe (bool): Whether to use coe. Defaults to False.
|
83 |
+
do_sample (bool): Whether to do sample. Defaults to False.
|
84 |
+
max_tokens_to_generate (Optional[int]): The max tokens to generate. If None, the model will attempt to use the maximum available tokens.
|
85 |
+
temperature (float): The model temperature. Defaults to 0.2.
|
86 |
+
select_expert (Optional[str]): The expert to use. Defaults to None.
|
87 |
+
"""
|
88 |
+
self.prod_mode = prod_mode
|
89 |
+
sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
|
90 |
+
self.api = api
|
91 |
+
self.llm = APIGateway.load_llm(
|
92 |
+
type=api,
|
93 |
+
streaming=True,
|
94 |
+
coe=coe,
|
95 |
+
do_sample=do_sample,
|
96 |
+
max_tokens_to_generate=max_tokens_to_generate,
|
97 |
+
temperature=temperature,
|
98 |
+
select_expert=select_expert,
|
99 |
+
process_prompt=False,
|
100 |
+
sambanova_api_key=sambanova_api_key,
|
101 |
+
)
|
102 |
+
|
103 |
+
if isinstance(tools, Tool):
|
104 |
+
tools = [tools]
|
105 |
+
self.tools = tools
|
106 |
+
if system_prompt is None:
|
107 |
+
system_prompt = ''
|
108 |
+
|
109 |
+
system_prompt = system_prompt.replace('{','{{').replace('}', '}}')
|
110 |
+
|
111 |
+
if len(self.tools) > 0:
|
112 |
+
system_prompt += '\n\n'
|
113 |
+
system_prompt += FUNCTION_CALLING_SYSTEM_PROMPT
|
114 |
+
self.system_prompt = system_prompt
|
115 |
+
|
116 |
+
if default_tool is None:
|
117 |
+
default_tool = conversational_response
|
118 |
+
|
119 |
+
def execute(self, invoked_tools: List[dict]) -> tuple[bool, List[str]]:
|
120 |
+
"""
|
121 |
+
Given a list of tool executions the llm return as required
|
122 |
+
execute them given the name with the mane in tools_map and the input arguments
|
123 |
+
if there is only one tool call and it is default conversational one, the response is marked as final response
|
124 |
+
|
125 |
+
Args:
|
126 |
+
invoked_tools (List[dict]): The list of tool executions generated by the LLM.
|
127 |
+
"""
|
128 |
+
if self.tools is not None:
|
129 |
+
tools_map = {tool.name.lower(): tool for tool in self.tools}
|
130 |
+
else:
|
131 |
+
tools_map = {}
|
132 |
+
tool_msg = "Tool '{name}' response: {response}"
|
133 |
+
tools_msgs = []
|
134 |
+
if len(invoked_tools) == 1 and invoked_tools[0]['tool'].lower() == 'conversationalresponse':
|
135 |
+
final_answer = True
|
136 |
+
return final_answer, [invoked_tools[0]['tool_input']['response']]
|
137 |
+
|
138 |
+
final_answer = False
|
139 |
+
|
140 |
+
for tool in invoked_tools:
|
141 |
+
if tool['tool'].lower() == 'invocationerror':
|
142 |
+
tools_msgs.append(f'Tool invocation error: {tool["tool_input"]}')
|
143 |
+
elif tool['tool'].lower() != 'conversationalresponse':
|
144 |
+
print(f"\n\n---\nTool {tool['tool'].lower()} invoked with input {tool['tool_input']}\n")
|
145 |
+
|
146 |
+
if tool['tool'].lower() not in tools_map:
|
147 |
+
tools_msgs.append(f'Tool {tool["tool"]} not found')
|
148 |
+
else:
|
149 |
+
response = tools_map[tool['tool'].lower()].call_tool_for_toolformer(**tool['tool_input'])
|
150 |
+
# print(f'Tool response: {str(response)}\n---\n\n')
|
151 |
+
tools_msgs.append(tool_msg.format(name=tool['tool'], response=str(response)))
|
152 |
+
return final_answer, tools_msgs
|
153 |
+
|
154 |
+
def json_finder(self, input_string: str) -> Optional[str]:
|
155 |
+
"""
|
156 |
+
find json structures in an LLM string response, if bad formatted using LLM to correct it
|
157 |
+
|
158 |
+
Args:
|
159 |
+
input_string (str): The string to find the json structure in.
|
160 |
+
"""
|
161 |
+
|
162 |
+
# 1. Ideal pattern: correctly surrounded by <ToolCalls> tags
|
163 |
+
json_pattern_1 = re.compile(r'<ToolCalls\>(.*)</ToolCalls\>', re.DOTALL + re.IGNORECASE)
|
164 |
+
# 2. Sometimes the closing tag is missing
|
165 |
+
json_pattern_2 = re.compile(r'<ToolCalls\>(.*)', re.DOTALL + re.IGNORECASE)
|
166 |
+
# 3. Sometimes it accidentally uses <ToolCall> instead of <ToolCalls>
|
167 |
+
json_pattern_3 = re.compile(r'<ToolCall\>(.*)</ToolCall\>', re.DOTALL + re.IGNORECASE)
|
168 |
+
# 4. Sometimes it accidentally uses <ToolCall> instead of <ToolCalls> and the closing tag is missing
|
169 |
+
json_pattern_4 = re.compile(r'<ToolCall\>(.*)', re.DOTALL + re.IGNORECASE)
|
170 |
+
|
171 |
+
# Find the first JSON structure in the string
|
172 |
+
json_match = json_pattern_1.search(input_string) or json_pattern_2.search(input_string) or json_pattern_3.search(input_string) or json_pattern_4.search(input_string)
|
173 |
+
if json_match:
|
174 |
+
json_str = json_match.group(1)
|
175 |
+
|
176 |
+
# 1. Outermost list of JSON object
|
177 |
+
call_pattern_1 = re.compile(r'\[.*\]', re.DOTALL)
|
178 |
+
# 2. Outermost JSON object
|
179 |
+
call_pattern_2 = re.compile(r'\{.*\}', re.DOTALL)
|
180 |
+
|
181 |
+
call_match_1 = call_pattern_1.search(json_str)
|
182 |
+
call_match_2 = call_pattern_2.search(json_str)
|
183 |
+
|
184 |
+
if call_match_1:
|
185 |
+
json_str = call_match_1.group(0)
|
186 |
+
try:
|
187 |
+
return json.loads(json_str)
|
188 |
+
except Exception as e:
|
189 |
+
return [{'tool': 'InvocationError', 'tool_input' : str(e)}]
|
190 |
+
elif call_match_2:
|
191 |
+
json_str = call_match_2.group(0)
|
192 |
+
try:
|
193 |
+
return [json.loads(json_str)]
|
194 |
+
except Exception as e:
|
195 |
+
return [{'tool': 'InvocationError', 'tool_input' : str(e)}]
|
196 |
+
else:
|
197 |
+
return [{'tool': 'InvocationError', 'tool_input' : 'Could not find JSON object in the <ToolCalls> section'}]
|
198 |
+
else:
|
199 |
+
dummy_json_response = [{'tool': 'ConversationalResponse', 'tool_input': {'response': input_string}}]
|
200 |
+
json_str = dummy_json_response
|
201 |
+
return json_str
|
202 |
+
|
203 |
+
def msgs_to_llama3_str(self, msgs: list) -> str:
|
204 |
+
"""
|
205 |
+
convert a list of langchain messages with roles to expected LLmana 3 input
|
206 |
+
|
207 |
+
Args:
|
208 |
+
msgs (list): The list of langchain messages.
|
209 |
+
"""
|
210 |
+
formatted_msgs = []
|
211 |
+
for msg in msgs:
|
212 |
+
if msg.type == 'system':
|
213 |
+
sys_placeholder = (
|
214 |
+
'<|begin_of_text|><|start_header_id|>system<|end_header_id|>system<|end_header_id|> {msg}'
|
215 |
+
)
|
216 |
+
formatted_msgs.append(sys_placeholder.format(msg=msg.content))
|
217 |
+
elif msg.type == 'human':
|
218 |
+
human_placeholder = '<|eot_id|><|start_header_id|>user<|end_header_id|>\nUser: {msg} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\nAssistant:' # noqa E501
|
219 |
+
formatted_msgs.append(human_placeholder.format(msg=msg.content))
|
220 |
+
elif msg.type == 'ai':
|
221 |
+
assistant_placeholder = '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\nAssistant: {msg}'
|
222 |
+
formatted_msgs.append(assistant_placeholder.format(msg=msg.content))
|
223 |
+
elif msg.type == 'tool':
|
224 |
+
tool_placeholder = '<|eot_id|><|start_header_id|>tools<|end_header_id|>\n{msg} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\nAssistant:' # noqa E501
|
225 |
+
formatted_msgs.append(tool_placeholder.format(msg=msg.content))
|
226 |
+
else:
|
227 |
+
raise ValueError(f'Invalid message type: {msg.type}')
|
228 |
+
return '\n'.join(formatted_msgs)
|
229 |
+
|
230 |
+
def msgs_to_sncloud(self, msgs: list) -> list:
|
231 |
+
"""
|
232 |
+
convert a list of langchain messages with roles to expected FastCoE input
|
233 |
+
|
234 |
+
Args:
|
235 |
+
msgs (list): The list of langchain messages.
|
236 |
+
"""
|
237 |
+
formatted_msgs = []
|
238 |
+
for msg in msgs:
|
239 |
+
if msg.type == 'system':
|
240 |
+
formatted_msgs.append({'role': 'system', 'content': msg.content})
|
241 |
+
elif msg.type == 'human':
|
242 |
+
formatted_msgs.append({'role': 'user', 'content': msg.content})
|
243 |
+
elif msg.type == 'ai':
|
244 |
+
formatted_msgs.append({'role': 'assistant', 'content': msg.content})
|
245 |
+
elif msg.type == 'tool':
|
246 |
+
formatted_msgs.append({'role': 'tools', 'content': msg.content})
|
247 |
+
else:
|
248 |
+
raise ValueError(f'Invalid message type: {msg.type}')
|
249 |
+
return json.dumps(formatted_msgs)
|
250 |
+
|
251 |
+
def function_call_llm(self, query: str, max_it: int = 10, debug: bool = False) -> str:
|
252 |
+
"""
|
253 |
+
invocation method for function calling workflow
|
254 |
+
|
255 |
+
Args:
|
256 |
+
query (str): The query to execute.
|
257 |
+
max_it (int, optional): The maximum number of iterations. Defaults to 5.
|
258 |
+
debug (bool, optional): Whether to print debug information. Defaults to False.
|
259 |
+
"""
|
260 |
+
function_calling_chat_template = ChatPromptTemplate.from_messages([('system', self.system_prompt)])
|
261 |
+
tools_schemas = [tool.as_llama_schema() for tool in self.tools]
|
262 |
+
|
263 |
+
history = function_calling_chat_template.format_prompt(tools=tools_schemas).to_messages()
|
264 |
+
|
265 |
+
history.append(HumanMessage(query))
|
266 |
+
tool_call_id = 0 # identification for each tool calling required to create ToolMessages
|
267 |
+
with usage_tracker():
|
268 |
+
|
269 |
+
for i in range(max_it):
|
270 |
+
json_parsing_chain = RunnableLambda(self.json_finder)
|
271 |
+
|
272 |
+
if self.api == 'sncloud':
|
273 |
+
prompt = self.msgs_to_sncloud(history)
|
274 |
+
else:
|
275 |
+
prompt = self.msgs_to_llama3_str(history)
|
276 |
+
# print(f'\n\n---\nCalling function calling LLM with prompt: \n{prompt}\n')
|
277 |
+
|
278 |
+
exponential_backoff_lower = 30
|
279 |
+
exponential_backoff_higher = 60
|
280 |
+
llm_response = None
|
281 |
+
for _ in range(5):
|
282 |
+
try:
|
283 |
+
llm_response = self.llm.invoke(prompt, stream_options={'include_usage': True})
|
284 |
+
break
|
285 |
+
except Exception as e:
|
286 |
+
if '429' in str(e):
|
287 |
+
print('Rate limit exceeded. Waiting with random exponential backoff.')
|
288 |
+
time.sleep(random() * (exponential_backoff_higher - exponential_backoff_lower) + exponential_backoff_lower)
|
289 |
+
exponential_backoff_lower *= 2
|
290 |
+
exponential_backoff_higher *= 2
|
291 |
+
else:
|
292 |
+
raise e
|
293 |
+
|
294 |
+
print('LLM response:', llm_response)
|
295 |
+
|
296 |
+
# print(f'\nFunction calling LLM response: \n{llm_response}\n---\n')
|
297 |
+
parsed_tools_llm_response = json_parsing_chain.invoke(llm_response)
|
298 |
+
|
299 |
+
history.append(AIMessage(llm_response))
|
300 |
+
final_answer, tools_msgs = self.execute(parsed_tools_llm_response)
|
301 |
+
if final_answer: # if response was marked as final response in execution
|
302 |
+
final_response = tools_msgs[0]
|
303 |
+
if debug:
|
304 |
+
print('\n\n---\nFinal function calling LLM history: \n')
|
305 |
+
pprint(f'{history}')
|
306 |
+
return final_response, get_total_usage()
|
307 |
+
else:
|
308 |
+
history.append(ToolMessage('\n'.join(tools_msgs), tool_call_id=tool_call_id))
|
309 |
+
tool_call_id += 1
|
310 |
+
|
311 |
+
|
312 |
+
raise Exception('Not a final response yet', history)
|
toolformers/sambanova/sambanova_langchain.py
ADDED
@@ -0,0 +1,827 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Langchain Wrapper around Sambanova LLM APIs."""
|
2 |
+
|
3 |
+
import json
|
4 |
+
from typing import Any, Dict, Generator, Iterator, List, Optional, Union
|
5 |
+
|
6 |
+
import requests
|
7 |
+
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
8 |
+
from langchain_core.language_models.llms import LLM
|
9 |
+
from langchain_core.outputs import GenerationChunk
|
10 |
+
from langchain_core.pydantic_v1 import Extra
|
11 |
+
from langchain_core.utils import get_from_dict_or_env, pre_init
|
12 |
+
from langchain_core.runnables import RunnableConfig, ensure_config
|
13 |
+
from langchain_core.language_models.base import (
|
14 |
+
LanguageModelInput,
|
15 |
+
)
|
16 |
+
from toolformers.sambanova.utils import append_to_usage_tracker
|
17 |
+
|
18 |
+
|
19 |
+
class SSEndpointHandler:
|
20 |
+
"""
|
21 |
+
SambaNova Systems Interface for SambaStudio model endpoints.
|
22 |
+
|
23 |
+
:param str host_url: Base URL of the DaaS API service
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, host_url: str, api_base_uri: str):
|
27 |
+
"""
|
28 |
+
Initialize the SSEndpointHandler.
|
29 |
+
|
30 |
+
:param str host_url: Base URL of the DaaS API service
|
31 |
+
:param str api_base_uri: Base URI of the DaaS API service
|
32 |
+
"""
|
33 |
+
self.host_url = host_url
|
34 |
+
self.api_base_uri = api_base_uri
|
35 |
+
self.http_session = requests.Session()
|
36 |
+
|
37 |
+
def _process_response(self, response: requests.Response) -> Dict:
|
38 |
+
"""
|
39 |
+
Processes the API response and returns the resulting dict.
|
40 |
+
|
41 |
+
All resulting dicts, regardless of success or failure, will contain the
|
42 |
+
`status_code` key with the API response status code.
|
43 |
+
|
44 |
+
If the API returned an error, the resulting dict will contain the key
|
45 |
+
`detail` with the error message.
|
46 |
+
|
47 |
+
If the API call was successful, the resulting dict will contain the key
|
48 |
+
`data` with the response data.
|
49 |
+
|
50 |
+
:param requests.Response response: the response object to process
|
51 |
+
:return: the response dict
|
52 |
+
:type: dict
|
53 |
+
"""
|
54 |
+
result: Dict[str, Any] = {}
|
55 |
+
try:
|
56 |
+
result = response.json()
|
57 |
+
except Exception as e:
|
58 |
+
result['detail'] = str(e)
|
59 |
+
if 'status_code' not in result:
|
60 |
+
result['status_code'] = response.status_code
|
61 |
+
return result
|
62 |
+
|
63 |
+
def _process_streaming_response(
|
64 |
+
self,
|
65 |
+
response: requests.Response,
|
66 |
+
) -> Generator[Dict, None, None]:
|
67 |
+
"""Process the streaming response"""
|
68 |
+
if 'api/predict/nlp' in self.api_base_uri:
|
69 |
+
try:
|
70 |
+
import sseclient
|
71 |
+
except ImportError:
|
72 |
+
raise ImportError(
|
73 |
+
'could not import sseclient library' 'Please install it with `pip install sseclient-py`.'
|
74 |
+
)
|
75 |
+
client = sseclient.SSEClient(response)
|
76 |
+
close_conn = False
|
77 |
+
for event in client.events():
|
78 |
+
if event.event == 'error_event':
|
79 |
+
close_conn = True
|
80 |
+
chunk = {
|
81 |
+
'event': event.event,
|
82 |
+
'data': event.data,
|
83 |
+
'status_code': response.status_code,
|
84 |
+
}
|
85 |
+
yield chunk
|
86 |
+
if close_conn:
|
87 |
+
client.close()
|
88 |
+
elif 'api/v2/predict/generic' in self.api_base_uri or 'api/predict/generic' in self.api_base_uri:
|
89 |
+
try:
|
90 |
+
for line in response.iter_lines():
|
91 |
+
chunk = json.loads(line)
|
92 |
+
if 'status_code' not in chunk:
|
93 |
+
chunk['status_code'] = response.status_code
|
94 |
+
yield chunk
|
95 |
+
except Exception as e:
|
96 |
+
raise RuntimeError(f'Error processing streaming response: {e}')
|
97 |
+
else:
|
98 |
+
raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')
|
99 |
+
|
100 |
+
def _get_full_url(self, path: str) -> str:
|
101 |
+
"""
|
102 |
+
Return the full API URL for a given path.
|
103 |
+
|
104 |
+
:param str path: the sub-path
|
105 |
+
:returns: the full API URL for the sub-path
|
106 |
+
:type: str
|
107 |
+
"""
|
108 |
+
return f'{self.host_url}/{self.api_base_uri}/{path}'
|
109 |
+
|
110 |
+
def nlp_predict(
|
111 |
+
self,
|
112 |
+
project: str,
|
113 |
+
endpoint: str,
|
114 |
+
key: str,
|
115 |
+
input: Union[List[str], str],
|
116 |
+
params: Optional[str] = '',
|
117 |
+
stream: bool = False,
|
118 |
+
) -> Dict:
|
119 |
+
"""
|
120 |
+
NLP predict using inline input string.
|
121 |
+
|
122 |
+
:param str project: Project ID in which the endpoint exists
|
123 |
+
:param str endpoint: Endpoint ID
|
124 |
+
:param str key: API Key
|
125 |
+
:param str input_str: Input string
|
126 |
+
:param str params: Input params string
|
127 |
+
:returns: Prediction results
|
128 |
+
:type: dict
|
129 |
+
"""
|
130 |
+
if isinstance(input, str):
|
131 |
+
input = [input]
|
132 |
+
if 'api/predict/nlp' in self.api_base_uri:
|
133 |
+
if params:
|
134 |
+
data = {'inputs': input, 'params': json.loads(params)}
|
135 |
+
else:
|
136 |
+
data = {'inputs': input}
|
137 |
+
elif 'api/v2/predict/generic' in self.api_base_uri:
|
            items = [{'id': f'item{i}', 'value': item} for i, item in enumerate(input)]
            if params:
                data = {'items': items, 'params': json.loads(params)}
            else:
                data = {'items': items}
        elif 'api/predict/generic' in self.api_base_uri:
            if params:
                data = {'instances': input, 'params': json.loads(params)}
            else:
                data = {'instances': input}
        else:
            raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')
        response = self.http_session.post(
            self._get_full_url(f'{project}/{endpoint}'),
            headers={'key': key},
            json=data,
        )
        return self._process_response(response)

    def nlp_predict_stream(
        self,
        project: str,
        endpoint: str,
        key: str,
        input: Union[List[str], str],
        params: Optional[str] = '',
    ) -> Iterator[Dict]:
        """
        NLP predict using inline input string.

        :param str project: Project ID in which the endpoint exists
        :param str endpoint: Endpoint ID
        :param str key: API Key
        :param str input: Input string
        :param str params: Input params string
        :returns: Prediction results
        :type: dict
        """
        if 'api/predict/nlp' in self.api_base_uri:
            if isinstance(input, str):
                input = [input]
            if params:
                data = {'inputs': input, 'params': json.loads(params)}
            else:
                data = {'inputs': input}
        elif 'api/v2/predict/generic' in self.api_base_uri:
            if isinstance(input, str):
                input = [input]
            items = [{'id': f'item{i}', 'value': item} for i, item in enumerate(input)]
            if params:
                data = {'items': items, 'params': json.loads(params)}
            else:
                data = {'items': items}
        elif 'api/predict/generic' in self.api_base_uri:
            if isinstance(input, list):
                input = input[0]
            if params:
                data = {'instance': input, 'params': json.loads(params)}
            else:
                data = {'instance': input}
        else:
            raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')
        # Streaming output
        response = self.http_session.post(
            self._get_full_url(f'stream/{project}/{endpoint}'),
            headers={'key': key},
            json=data,
            stream=True,
        )
        for chunk in self._process_streaming_response(response):
            yield chunk


class SambaStudio(LLM):
    """
    SambaStudio large language models.

    To use, you should have the environment variables
    ``SAMBASTUDIO_BASE_URL`` set with your SambaStudio environment URL.
    ``SAMBASTUDIO_BASE_URI`` set with your SambaStudio api base URI.
    ``SAMBASTUDIO_PROJECT_ID`` set with your SambaStudio project ID.
    ``SAMBASTUDIO_ENDPOINT_ID`` set with your SambaStudio endpoint ID.
    ``SAMBASTUDIO_API_KEY`` set with your SambaStudio endpoint API key.

    https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite

    Read extra documentation in https://docs.sambanova.ai/sambastudio/latest/index.html

    Example:
    .. code-block:: python

        from langchain_community.llms.sambanova import SambaStudio
        SambaStudio(
            sambastudio_base_url="your-SambaStudio-environment-URL",
            sambastudio_base_uri="your-SambaStudio-base-URI",
            sambastudio_project_id="your-SambaStudio-project-ID",
            sambastudio_endpoint_id="your-SambaStudio-endpoint-ID",
            sambastudio_api_key="your-SambaStudio-endpoint-API-key",
            streaming=False,
            model_kwargs={
                "do_sample": False,
                "max_tokens_to_generate": 1000,
                "temperature": 0.7,
                "top_p": 1.0,
                "repetition_penalty": 1,
                "top_k": 50,
                # "process_prompt": False,
                # "select_expert": "Meta-Llama-3-8B-Instruct"
            },
        )
    """

    sambastudio_base_url: str = ''
    """Base url to use"""

    sambastudio_base_uri: str = ''
    """endpoint base uri"""

    sambastudio_project_id: str = ''
    """Project id on sambastudio for model"""

    sambastudio_endpoint_id: str = ''
    """endpoint id on sambastudio for model"""

    sambastudio_api_key: str = ''
    """sambastudio api key"""

    model_kwargs: Optional[dict] = None
    """Key word arguments to pass to the model."""

    streaming: Optional[bool] = False
    """Streaming flag to get streamed response."""

    class Config:
        """Configuration for this pydantic object."""

        extra = 'forbid'  # Extra.forbid

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{'model_kwargs': self.model_kwargs}}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return 'Sambastudio LLM'

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values['sambastudio_base_url'] = get_from_dict_or_env(values, 'sambastudio_base_url', 'SAMBASTUDIO_BASE_URL')
        values['sambastudio_base_uri'] = get_from_dict_or_env(
            values,
            'sambastudio_base_uri',
            'SAMBASTUDIO_BASE_URI',
            default='api/predict/generic',
        )
        values['sambastudio_project_id'] = get_from_dict_or_env(
            values, 'sambastudio_project_id', 'SAMBASTUDIO_PROJECT_ID'
        )
        values['sambastudio_endpoint_id'] = get_from_dict_or_env(
            values, 'sambastudio_endpoint_id', 'SAMBASTUDIO_ENDPOINT_ID'
        )
        values['sambastudio_api_key'] = get_from_dict_or_env(values, 'sambastudio_api_key', 'SAMBASTUDIO_API_KEY')
        return values

    def _get_tuning_params(self, stop: Optional[List[str]]) -> str:
        """
        Get the tuning parameters to use when calling the LLM.

        Args:
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.

        Returns:
            The tuning parameters as a JSON string.
        """
        _model_kwargs = self.model_kwargs or {}
        _kwarg_stop_sequences = _model_kwargs.get('stop_sequences', [])
        _stop_sequences = stop or _kwarg_stop_sequences
        # if not _kwarg_stop_sequences:
        #     _model_kwargs["stop_sequences"] = ",".join(
        #         f'"{x}"' for x in _stop_sequences
        #     )
        if 'api/v2/predict/generic' in self.sambastudio_base_uri:
            tuning_params_dict = _model_kwargs
        else:
            tuning_params_dict = {k: {'type': type(v).__name__, 'value': str(v)} for k, v in (_model_kwargs.items())}
        # _model_kwargs["stop_sequences"] = _kwarg_stop_sequences
        tuning_params = json.dumps(tuning_params_dict)
        return tuning_params

    def _handle_nlp_predict(self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str) -> str:
        """
        Perform an NLP prediction using the SambaStudio endpoint handler.

        Args:
            sdk: The SSEndpointHandler to use for the prediction.
            prompt: The prompt to use for the prediction.
            tuning_params: The tuning parameters to use for the prediction.

        Returns:
            The prediction result.

        Raises:
            ValueError: If the prediction fails.
        """
        response = sdk.nlp_predict(
            self.sambastudio_project_id,
            self.sambastudio_endpoint_id,
            self.sambastudio_api_key,
            prompt,
            tuning_params,
        )
        if response['status_code'] != 200:
            optional_detail = response.get('detail')
            if optional_detail:
                raise RuntimeError(
                    f"Sambanova /complete call failed with status code "
                    f"{response['status_code']}.\n Details: {optional_detail}"
                )
            else:
                raise RuntimeError(
                    f"Sambanova /complete call failed with status code "
                    f"{response['status_code']}.\n response {response}"
                )
        if 'api/predict/nlp' in self.sambastudio_base_uri:
            return response['data'][0]['completion']
        elif 'api/v2/predict/generic' in self.sambastudio_base_uri:
            return response['items'][0]['value']['completion']
        elif 'api/predict/generic' in self.sambastudio_base_uri:
            return response['predictions'][0]['completion']
        else:
            raise ValueError(f'handling of endpoint uri: {self.sambastudio_base_uri} not implemented')

    def _handle_completion_requests(self, prompt: Union[List[str], str], stop: Optional[List[str]]) -> str:
        """
        Perform a prediction using the SambaStudio endpoint handler.

        Args:
            prompt: The prompt to use for the prediction.
            stop: stop sequences.

        Returns:
            The prediction result.

        Raises:
            ValueError: If the prediction fails.
        """
        ss_endpoint = SSEndpointHandler(self.sambastudio_base_url, self.sambastudio_base_uri)
        tuning_params = self._get_tuning_params(stop)
        return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)

    def _handle_nlp_predict_stream(
        self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str
    ) -> Iterator[GenerationChunk]:
        """
        Perform a streaming request to the LLM.

        Args:
            sdk: The SSEndpointHandler to use for the prediction.
            prompt: The prompt to use for the prediction.
            tuning_params: The tuning parameters to use for the prediction.

        Returns:
            An iterator of GenerationChunks.
        """
        for chunk in sdk.nlp_predict_stream(
            self.sambastudio_project_id,
            self.sambastudio_endpoint_id,
            self.sambastudio_api_key,
            prompt,
            tuning_params,
        ):
            if chunk['status_code'] != 200:
                error = chunk.get('error')
                if error:
                    optional_code = error.get('code')
                    optional_details = error.get('details')
                    optional_message = error.get('message')
                    raise ValueError(
                        f"Sambanova /complete call failed with status code "
                        f"{chunk['status_code']}.\n"
                        f"Message: {optional_message}\n"
                        f"Details: {optional_details}\n"
                        f"Code: {optional_code}\n"
                    )
                else:
                    raise RuntimeError(
                        f"Sambanova /complete call failed with status code " f"{chunk['status_code']}." f"{chunk}."
                    )
            if 'api/predict/nlp' in self.sambastudio_base_uri:
                text = json.loads(chunk['data'])['stream_token']
            elif 'api/v2/predict/generic' in self.sambastudio_base_uri:
                text = chunk['result']['items'][0]['value']['stream_token']
            elif 'api/predict/generic' in self.sambastudio_base_uri:
                if len(chunk['result']['responses']) > 0:
                    text = chunk['result']['responses'][0]['stream_token']
                else:
                    text = ''
            else:
                raise ValueError(f'handling of endpoint uri: {self.sambastudio_base_uri} ' f'not implemented')
            generated_chunk = GenerationChunk(text=text)
            yield generated_chunk

    def _stream(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Call out to Sambanova's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.
        """
        ss_endpoint = SSEndpointHandler(self.sambastudio_base_url, self.sambastudio_base_uri)
        tuning_params = self._get_tuning_params(stop)
        try:
            if self.streaming:
                for chunk in self._handle_nlp_predict_stream(ss_endpoint, prompt, tuning_params):
                    if run_manager:
                        run_manager.on_llm_new_token(chunk.text)
                    yield chunk
            else:
                return
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f'Error raised by the inference endpoint: {e}') from e

    def _handle_stream_request(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]],
        run_manager: Optional[CallbackManagerForLLMRun],
        kwargs: Dict[str, Any],
    ) -> str:
        """
        Perform a streaming request to the LLM.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.
            run_manager: Callback manager for the run.
            **kwargs: Additional keyword arguments, directly passed
                to the sambastudio model in the API call.

        Returns:
            The model output as a string.
        """
        completion = ''
        for chunk in self._stream(prompt=prompt, stop=stop, run_manager=run_manager, **kwargs):
            completion += chunk.text
        return completion

    def _call(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Sambanova's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.
        """
        if stop is not None:
            raise Exception('stop not implemented')
        try:
            if self.streaming:
                return self._handle_stream_request(prompt, stop, run_manager, kwargs)
            return self._handle_completion_requests(prompt, stop)
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f'Error raised by the inference endpoint: {e}') from e


class SambaNovaCloud(LLM):
    """
    SambaNova Cloud large language models.

    To use, you should have the environment variables
    ``SAMBANOVA_URL`` set with your SambaNova Cloud URL.
    ``SAMBANOVA_API_KEY`` set with your SambaNova Cloud API Key.

    http://cloud.sambanova.ai/

    Example:
    .. code-block:: python

        SambaNovaCloud(
            sambanova_url = SambaNova cloud endpoint URL,
            sambanova_api_key = set with your SambaNova cloud API key,
            max_tokens = max number of tokens to generate,
            stop_tokens = list of stop tokens,
            model = model name
        )
    """

    sambanova_url: str = ''
    """SambaNova Cloud Url"""

    sambanova_api_key: str = ''
    """SambaNova Cloud api key"""

    max_tokens: int = 4000
    """max tokens to generate"""

    stop_tokens: list = ['<|eot_id|>']
    """Stop tokens"""

    model: str = 'llama3-8b'
    """LLM model expert to use"""

    temperature: float = 0.0
    """model temperature"""

    top_p: float = 0.0
    """model top p"""

    top_k: int = 1
    """model top k"""

    stream_api: bool = True
    """use stream api"""

    stream_options: dict = {'include_usage': True}
    """stream options, include usage to get generation metrics"""

    class Config:
        """Configuration for this pydantic object."""

        extra = 'forbid'  # Extra.forbid

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            'model': self.model,
            'max_tokens': self.max_tokens,
            'stop': self.stop_tokens,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'top_k': self.top_k,
        }

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        config = ensure_config(config)

        print('Invoking SambaNovaCloud with input:', input)
        response = self.generate_prompt(
            [self._convert_input(input)],
            stop=stop,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
            run_id=config.pop("run_id", None),
            **kwargs,
        )
        run_infos = response.run

        if len(run_infos) > 1:
            raise NotImplementedError('Multiple runs not supported')

        run_id = run_infos[0].run_id

        # print('Raw response:', response.run)
        # print('Run ID:', run_id)
        # print(USAGE_TRACKER)
        # if run_id in USAGE_TRACKER:
        #     print('Usage:', USAGE_TRACKER[run_id])
        # return response
        return (
            response
            .generations[0][0]
            .text
        )

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return 'SambaNova Cloud'

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values['sambanova_url'] = get_from_dict_or_env(
            values, 'sambanova_url', 'SAMBANOVA_URL', default='https://api.sambanova.ai/v1/chat/completions'
        )
        values['sambanova_api_key'] = get_from_dict_or_env(values, 'sambanova_api_key', 'SAMBANOVA_API_KEY')
        return values

    def _handle_nlp_predict_stream(
        self,
        prompt: Union[List[str], str],
        stop: List[str],
    ) -> Iterator[GenerationChunk]:
        """
        Perform a streaming request to the LLM.

        Args:
            prompt: The prompt to use for the prediction.
            stop: list of stop tokens

        Returns:
            An iterator of GenerationChunks.
        """
        try:
            import sseclient
        except ImportError:
            raise ImportError('could not import sseclient library. ' 'Please install it with `pip install sseclient-py`.')
        try:
            formatted_prompt = json.loads(prompt)
        except Exception:
            formatted_prompt = [{'role': 'user', 'content': prompt}]

        http_session = requests.Session()
        if not stop:
            stop = self.stop_tokens
        data = {
            'messages': formatted_prompt,
            'max_tokens': self.max_tokens,
            'stop': stop,
            'model': self.model,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'top_k': self.top_k,
            'stream': self.stream_api,
            'stream_options': self.stream_options,
        }
        # Streaming output
        response = http_session.post(
            self.sambanova_url,
            headers={'Authorization': f'Bearer {self.sambanova_api_key}', 'Content-Type': 'application/json'},
            json=data,
            stream=True,
        )

        client = sseclient.SSEClient(response)
        close_conn = False

        print('Response:', response)

        if response.status_code != 200:
            raise RuntimeError(
                f'Sambanova /complete call failed with status code ' f'{response.status_code}.' f'{response.text}.'
            )

        for event in client.events():
            if event.event == 'error_event':
                close_conn = True
            # print('Event:', event.data)
            chunk = {
                'event': event.event,
                'data': event.data,
                'status_code': response.status_code,
            }

            if chunk.get('error'):
                raise RuntimeError(
                    f"Sambanova /complete call failed with status code " f"{chunk['status_code']}." f"{chunk}."
                )

            try:
                # check if the response is a final event; in that case the event data is '[DONE]'
                # if 'usage' in chunk['data']:
                #     usage = json.loads(chunk['data'])
                #     print('Usage:', usage)
                if chunk['data'] != '[DONE]':
                    data = json.loads(chunk['data'])
                    if data.get('error'):
                        raise RuntimeError(
                            f"Sambanova /complete call failed with status code " f"{chunk['status_code']}." f"{chunk}."
                        )
                    # check if the response is a final response with usage stats (does not include content)
                    if data.get('usage') is None:
                        # check it is not an "end of text" response
                        if data['choices'][0]['finish_reason'] is None:
                            text = data['choices'][0]['delta']['content']
                            generated_chunk = GenerationChunk(text=text)
                            yield generated_chunk
                    else:
                        # if data['id'] not in USAGE_TRACKER:
                        #     USAGE_TRACKER[data['id']] = []
                        # USAGE_TRACKER[data['id']].append(data['usage'])
                        append_to_usage_tracker(data['usage'])
                        # print(f'Usage for id {data["id"]}:', data['usage'])
            except Exception as e:
                raise Exception(f'Error getting content chunk raw streamed response: {chunk}') from e

    def _stream(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Call out to Sambanova's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.
        """
        try:
            for chunk in self._handle_nlp_predict_stream(prompt, stop):
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text)
                yield chunk
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f'Error raised by the inference endpoint: {e}') from e

    def _handle_stream_request(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]],
        run_manager: Optional[CallbackManagerForLLMRun],
        kwargs: Dict[str, Any],
    ) -> str:
        """
        Perform a streaming request to the LLM.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.
            run_manager: Callback manager for the run.
            **kwargs: Additional keyword arguments, directly passed
                to the Sambanova Cloud model in the API call.

        Returns:
            The model output as a string.
        """
        completion = ''
        for chunk in self._stream(prompt=prompt, stop=stop, run_manager=run_manager, **kwargs):
            completion += chunk.text
        return completion

    def _call(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Sambanova's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.
        """
        try:
            return self._handle_stream_request(prompt, stop, run_manager, kwargs)
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f'Error raised by the inference endpoint: {e}') from e
toolformers/sambanova/utils.py
ADDED
@@ -0,0 +1,32 @@
from contextlib import contextmanager

USAGE_TRACKER = None

@contextmanager
def usage_tracker():
    global USAGE_TRACKER
    assert USAGE_TRACKER is None
    USAGE_TRACKER = []
    try:
        yield
    finally:
        USAGE_TRACKER = None

def get_total_usage():
    global USAGE_TRACKER

    prompt_tokens = 0
    completion_tokens = 0

    for usage in USAGE_TRACKER:
        prompt_tokens += usage['prompt_tokens']
        completion_tokens += usage['completion_tokens']

    return {
        'prompt_tokens': prompt_tokens,
        'completion_tokens': completion_tokens
    }

def append_to_usage_tracker(usage):
    global USAGE_TRACKER
    USAGE_TRACKER.append(usage)
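This tracker is a process-global accumulator: usage_tracker() opens a collection window, append_to_usage_tracker records each usage payload (normally called from SambaNovaCloud's streaming loop), and get_total_usage() sums the token counts while the window is open. A short sketch of the intended flow, with made-up payload dicts purely for illustration:

    from toolformers.sambanova.utils import usage_tracker, append_to_usage_tracker, get_total_usage

    with usage_tracker():
        # Normally these payloads come from the streaming API response.
        append_to_usage_tracker({'prompt_tokens': 120, 'completion_tokens': 45})
        append_to_usage_tracker({'prompt_tokens': 80, 'completion_tokens': 30})
        print(get_total_usage())  # {'prompt_tokens': 200, 'completion_tokens': 75}
    # Outside the context manager, USAGE_TRACKER is reset to None.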
utils.py
ADDED
@@ -0,0 +1,80 @@
import base64
import hashlib

from functools import wraps
from typing import Callable, Any
from inspect import signature, Parameter, Signature


def extract_substring(text, start_tag, end_tag):
    start_position = text.lower().find(start_tag.lower())
    end_position = text.lower().find(end_tag.lower())

    if start_position == -1 or end_position == -1:
        return None

    return text[start_position + len(start_tag):end_position].strip()

def compute_hash(s):
    # Hash a string using SHA-1 and return the base64 encoded result

    m = hashlib.sha1()
    m.update(s.encode())

    b = m.digest()

    return base64.b64encode(b).decode('ascii')

def add_params_and_annotations(name: str, description: str, params: dict, return_type: Any):
    """
    A decorator to add parameters and a return type annotation to a function.

    :param params: A dictionary where the keys are parameter names and values are their types.
    :param return_type: The return type to add to the function's signature.
    """
    def decorator(func: Callable):
        # Create new parameters based on the provided params dict
        new_params = [
            Parameter(name, Parameter.POSITIONAL_OR_KEYWORD, annotation=type_)
            for name, (type_, _) in params.items()
        ]

        # Get the existing signature and parameters
        original_sig = signature(func)
        original_params = list(original_sig.parameters.values())

        # Combine new parameters with the existing ones
        combined_params = new_params  # + original_params

        # Create a new signature with updated parameters and return annotation
        new_sig = Signature(parameters=combined_params, return_annotation=return_type)

        # Define the wrapper function
        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        docstring = description

        if len(params) > 0:
            docstring += '\n\nArgs:'
            for param_name, (type_, param_description) in params.items():
                docstring += f'\n {param_name}: {param_description}'

        docstring += f'\n\nReturns:\n {return_type.__name__}'

        print('Input params:', params)


        # Set the new signature on the wrapper
        wrapper.__name__ = name
        wrapper.__signature__ = new_sig
        wrapper.__annotations__.update({ k: v[0] for k, v in params.items() })
        wrapper.__annotations__['return'] = return_type
        # wrapper.__annotations__['name'] = str
        wrapper.__doc__ = docstring

        print(docstring)

        return wrapper
    return decorator
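The decorator rewrites a function's public name, signature, annotations, and docstring so that schema-driven callers (for example, toolformer-style function calling) see the declared parameters instead of a generic wrapper. A short sketch of it in use; the tool below is hypothetical and only illustrates the rewritten metadata (it assumes the repository root is on sys.path so utils is importable):

    import inspect
    from utils import add_params_and_annotations  # hypothetical import path

    @add_params_and_annotations(
        name='get_forecast',
        description='Return a dummy forecast for a location.',
        params={'location': (str, 'City name to look up')},
        return_type=dict,
    )
    def _impl(location):
        return {'location': location, 'condition': 'sunny'}

    print(_impl.__name__)            # get_forecast
    print(inspect.signature(_impl))  # (location: str) -> dict
    print(_impl('Lisbon'))           # {'location': 'Lisbon', 'condition': 'sunny'}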