samuelemarro committed on
Commit
3cad23b
·
0 Parent(s):

Initial upload to test HF Spaces.

.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ venv/**
2
+ venv2/**
3
+ *.pyc
4
+ .env
app.py ADDED
@@ -0,0 +1,276 @@
1
+ import json
2
+
3
+ import gradio as gr
4
+
5
+ from collections import UserList
6
+
7
+ from flow import full_flow
8
+
9
+ schema = {
10
+ "input": {
11
+ "type": "object",
12
+ "properties": {
13
+ "location": {
14
+ "type": "string",
15
+ "description": "The name of the location for which the weather forecast is requested."
16
+ },
17
+ "date": {
18
+ "type": "string",
19
+ "format": "date",
20
+ "description": "The date for which the weather forecast is requested, in YYYY-MM-DD format."
21
+ }
22
+ },
23
+ "required": [
24
+ "location",
25
+ "date"
26
+ ]
27
+ },
28
+ "output": {
29
+ "type": "object",
30
+ "properties": {
31
+ "temperature": {
32
+ "type": "number",
33
+ "description": "The forecasted temperature in degrees Celsius."
34
+ },
35
+ "condition": {
36
+ "type": "string",
37
+ "description": "A brief description of the weather condition (e.g., sunny, cloudy, rainy)."
38
+ },
39
+ "humidity": {
40
+ "type": "number",
41
+ "description": "The forecasted humidity percentage."
42
+ },
43
+ "wind_speed": {
44
+ "type": "number",
45
+ "description": "The forecasted wind speed in kilometers per hour."
46
+ }
47
+ },
48
+ "required": [
49
+ "temperature",
50
+ "condition",
51
+ "humidity",
52
+ "wind_speed"
53
+ ]
54
+ },
55
+ "description": "Alice requests a weather forecast for a specific location and date from Bob's weather service.",
56
+ "examples": [
57
+ {
58
+ "location": "New York",
59
+ "date": "2023-10-15"
60
+ },
61
+ {
62
+ "location": "London",
63
+ "date": "2023-11-01"
64
+ }
65
+ ],
66
+ "tools": [
67
+ {
68
+ "name": "WeatherForecastAPI",
69
+ "description": "An API that provides weather forecasts for a given location and date.",
70
+ "input": {
71
+ "type": "object",
72
+ "properties": {
73
+ "location": {
74
+ "type": "string",
75
+ "description": "The name of the location for which the weather forecast is requested."
76
+ },
77
+ "date": {
78
+ "type": "string",
79
+ "format": "date",
80
+ "description": "The date for which the weather forecast is requested, in YYYY-MM-DD format."
81
+ }
82
+ },
83
+ "required": [
84
+ "location",
85
+ "date"
86
+ ]
87
+ },
88
+ "output": {
89
+ "type": "object",
90
+ "properties": {
91
+ "temperature": {
92
+ "type": "number",
93
+ "description": "The forecasted temperature in degrees Celsius."
94
+ },
95
+ "condition": {
96
+ "type": "string",
97
+ "description": "A brief description of the weather condition (e.g., sunny, cloudy, rainy)."
98
+ },
99
+ "humidity": {
100
+ "type": "number",
101
+ "description": "The forecasted humidity percentage."
102
+ },
103
+ "wind_speed": {
104
+ "type": "number",
105
+ "description": "The forecasted wind speed in kilometers per hour."
106
+ }
107
+ },
108
+ "required": [
109
+ "temperature",
110
+ "condition",
111
+ "humidity",
112
+ "wind_speed"
113
+ ]
114
+ },
115
+ "dummy_outputs": [
116
+ {
117
+ "temperature": 18,
118
+ "condition": "Sunny",
119
+ "humidity": 55,
120
+ "wind_speed": 10
121
+ },
122
+ {
123
+ "temperature": 12,
124
+ "condition": "Cloudy",
125
+ "humidity": 80,
126
+ "wind_speed": 15
127
+ }
128
+ ]
129
+ }
130
+ ]
131
+ }
132
+
133
+ SCHEMAS = {
134
+ "weather_forecast": schema,
135
+ "other": { "input": "PIPPO"}
136
+ }
137
+
138
+ def parse_raw_messages(messages_raw):
139
+ messages_clean = []
140
+ messages_agora = []
141
+
142
+ for message in messages_raw:
143
+ role = message['role']
144
+ message_without_role = dict(message)
145
+ del message_without_role['role']
146
+
147
+ messages_agora.append({
148
+ 'role': role,
149
+ 'content': '```\n' + json.dumps(message_without_role, indent=2) + '\n```'
150
+ })
151
+
152
+ if message.get('status') == 'error':
153
+ messages_clean.append({
154
+ 'role': role,
155
+ 'content': f"Error: {message['message']}"
156
+ })
157
+ else:
158
+ messages_clean.append({
159
+ 'role': role,
160
+ 'content': message['body']
161
+ })
162
+
163
+ return messages_clean, messages_agora
164
+
165
+ def main():
166
+ with gr.Blocks() as demo:
167
+ gr.Markdown("### Agora Demo")
168
+ gr.Markdown("We will create a new Agora channel and offer it to Alice as a tool.")
169
+
170
+ chosen_task = gr.Dropdown(choices=list(SCHEMAS.keys()), label="Schema", value="weather_forecast")
171
+ custom_task = gr.Checkbox(label="Custom Task")
172
+
173
+ STATE_TRACKER = {}
174
+
175
+ @gr.render(inputs=[chosen_task, custom_task])
176
+ def render(chosen_task, custom_task):
177
+ if STATE_TRACKER.get('chosen_task') != chosen_task:
178
+ STATE_TRACKER['chosen_task'] = chosen_task
179
+ for k, v in SCHEMAS[chosen_task].items():
180
+ if isinstance(v, str):
181
+ STATE_TRACKER[k] = v
182
+ else:
183
+ STATE_TRACKER[k] = json.dumps(v, indent=2)
184
+
185
+ if custom_task:
186
+ gr.Text(label="Description", value=STATE_TRACKER["description"], interactive=True).change(lambda x: STATE_TRACKER.update({'description': x}))
187
+ gr.TextArea(label="Input Schema", value=STATE_TRACKER["input"], interactive=True).change(lambda x: STATE_TRACKER.update({'input': x}))
188
+ gr.TextArea(label="Output Schema", value=STATE_TRACKER["output"], interactive=True).change(lambda x: STATE_TRACKER.update({'output': x}))
189
+ gr.TextArea(label="Tools", value=STATE_TRACKER["tools"], interactive=True).change(lambda x: STATE_TRACKER.update({'tools': x}))
190
+ gr.TextArea(label="Examples", value=STATE_TRACKER["examples"], interactive=True).change(lambda x: STATE_TRACKER.update({'examples': x}))
191
+
192
+ model_options = [
193
+ ('GPT 4o (Camel AI)', 'gpt-4o'),
194
+ ('GPT 4o-mini (Camel AI)', 'gpt-4o-mini'),
195
+ ('Claude 3 Sonnet (LangChain)', 'claude-3-sonnet'),
196
+ ('Gemini 1.5 Pro (Google GenAI)', 'gemini-1.5-pro'),
197
+ ('Llama3 405B (Sambanova + LangChain)', 'llama3-405b')
198
+ ]
199
+
200
+ fallback_image = ''
201
+
202
+ images = {
203
+ 'gpt-4o': 'https://uxwing.com/wp-content/themes/uxwing/download/brands-and-social-media/chatgpt-icon.png',
204
+ 'gpt-4o-mini': 'https://uxwing.com/wp-content/themes/uxwing/download/brands-and-social-media/chatgpt-icon.png',
205
+ 'claude-3-sonnet': 'https://play-lh.googleusercontent.com/4S1nfdKsH_1tJodkHrBHimqlCTE6qx6z22zpMyPaMc_Rlr1EdSFDI1I6UEVMnokG5zI',
206
+ 'gemini-1.5-pro': 'https://uxwing.com/wp-content/themes/uxwing/download/brands-and-social-media/google-gemini-icon.png',
207
+ 'llama3-405b': 'https://www.designstub.com/png-resources/wp-content/uploads/2023/03/meta-icon-social-media-flat-graphic-vector-3-novem.png'
208
+ }
209
+
210
+ with gr.Row(equal_height=True):
211
+ with gr.Column(scale=1):
212
+ alice_model_dd = gr.Dropdown(label="Alice Model", choices=model_options, value="gpt-4o")
213
+ with gr.Column(scale=1):
214
+ bob_model_dd = gr.Dropdown(label="Bob Model", choices=model_options, value="gpt-4o")
215
+
216
+ button = gr.Button('Start', elem_id='start_button')
217
+ gr.Markdown('### Natural Language')
218
+
219
+ @gr.render(inputs=[alice_model_dd, bob_model_dd])
220
+ def render_with_images(alice_model, bob_model):
221
+ avatar_images = [images.get(alice_model, fallback_image), images.get(bob_model, fallback_image)]
222
+ chatbot_nl = gr.Chatbot(type="messages", avatar_images=avatar_images)
223
+
224
+ with gr.Accordion(label="Raw Messages", open=False):
225
+ chatbot_nl_raw = gr.Chatbot(type="messages", avatar_images=avatar_images)
226
+
227
+ gr.Markdown('### Negotiation')
228
+ chatbot_negotiation = gr.Chatbot(type="messages", avatar_images=avatar_images)
229
+
230
+ gr.Markdown('### Protocol')
231
+ protocol_result = gr.TextArea(interactive=False, label="Protocol")
232
+
233
+ gr.Markdown('### Implementation')
234
+ with gr.Row():
235
+ with gr.Column(scale=1):
236
+ alice_implementation = gr.TextArea(interactive=False, label="Alice Implementation")
237
+ with gr.Column(scale=1):
238
+ bob_implementation = gr.TextArea(interactive=False, label="Bob Implementation")
239
+
240
+ gr.Markdown('### Structured Communication')
241
+ structured_communication = gr.Chatbot(type="messages", avatar_images=avatar_images)
242
+
243
+ with gr.Accordion(label="Raw Messages", open=False):
244
+ structured_communication_raw = gr.Chatbot(type="messages", avatar_images=avatar_images)
245
+
246
+ def respond(chosen_task, custom_task, alice_model, bob_model):
247
+ yield gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), \
248
+ None, None, None, None, None, None, None, None
249
+
250
+ if custom_task:
251
+ schema = dict(STATE_TRACKER)
252
+ for k, v in schema.items():
253
+ if isinstance(v, str):
254
+ try:
255
+ schema[k] = json.loads(v)
256
+ except:
257
+ pass
258
+ else:
259
+ schema = SCHEMAS[chosen_task]
260
+
261
+ for nl_messages_raw, negotiation_messages, structured_messages_raw, protocol, alice_implementation, bob_implementation in full_flow(schema, alice_model, bob_model):
262
+ nl_messages_clean, nl_messages_agora = parse_raw_messages(nl_messages_raw)
263
+ structured_messages_clean, structured_messages_agora = parse_raw_messages(structured_messages_raw)
264
+
265
+ yield gr.update(), gr.update(), gr.update(), nl_messages_clean, nl_messages_agora, negotiation_messages, structured_messages_clean, structured_messages_agora, protocol, alice_implementation, bob_implementation
266
+
267
+ #yield from full_flow(schema, alice_model, bob_model)
268
+ yield gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
269
+
270
+ button.click(respond, [chosen_task, custom_task, alice_model_dd, bob_model_dd], [button, alice_model_dd, bob_model_dd, chatbot_nl, chatbot_nl_raw, chatbot_negotiation, structured_communication, structured_communication_raw, protocol_result, alice_implementation, bob_implementation])
271
+
272
+ demo.launch()
273
+
274
+
275
+ if __name__ == '__main__':
276
+ main()
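A short sketch of the message format that `parse_raw_messages` above expects (the dictionaries mirror what `flow.py` below appends to its message lists; the example values are hypothetical):

```python
from app import parse_raw_messages

# Raw messages as produced by flow.py: natural-language turns carry a 'body',
# failed turns carry status='error' plus a 'message'.
raw = [
    {'role': 'assistant', 'body': 'What is the weather in New York on 2023-10-15?', 'protocolHash': None},
    {'role': 'user', 'status': 'success', 'body': 'Sunny, 18 C, 55% humidity, wind 10 km/h'},
    {'role': 'user', 'status': 'error', 'message': 'Connection refused'},
]

clean, agora = parse_raw_messages(raw)
# `clean` keeps only the readable text (or "Error: ..." strings) for the main chatbots;
# `agora` wraps each full message in a fenced JSON block for the "Raw Messages" accordions.
```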
executor.py ADDED
@@ -0,0 +1,25 @@
1
+ from abc import abstractmethod
2
+ import importlib.util
3
+ from typing import List
4
+
5
+ from toolformers.base import Tool
6
+
7
+ class Executor:
8
+ @abstractmethod
9
+ def run_routine(self, protocol_id, code, task_data, tools):
10
+ pass
11
+
12
+ class UnsafeExecutor(Executor):
13
+ def run_routine(self, protocol_id, code, task_data, tools : List[Tool]):
14
+ protocol_id = protocol_id.replace('-', '_').replace('.', '_').replace('/', '_')
15
+ # TODO: This should be done in a safe, containerized environment
16
+ spec = importlib.util.spec_from_loader(protocol_id, loader=None)
17
+ loaded_module = importlib.util.module_from_spec(spec)
18
+
19
+ #spec.loader.exec_module(loaded_module)
20
+ exec(code, loaded_module.__dict__)
21
+
22
+ for tool in tools:
23
+ loaded_module.__dict__[tool.name] = tool.as_executable_function()
24
+
25
+ return loaded_module.run(task_data)
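A minimal usage sketch for `UnsafeExecutor` (this is roughly what `flow.py` below does; the protocol id, implementation string and dummy reply are made up, and it assumes the repository's `utils` module and the `google-generativeai` dependency imported by `toolformers/base.py` are available):

```python
from executor import UnsafeExecutor
from toolformers.base import Tool, StringParameter

def send_to_server(query):
    # Hypothetical stand-in for the structured-communication callback in flow.py.
    return '{"temperature": 18, "condition": "Sunny", "humidity": 55, "wind_speed": 10}'

send_tool = Tool('send_to_server', 'Send to server',
                 [StringParameter('query', 'The query', True)], send_to_server)

# LLM-generated routine: the executor exec()s it into a fresh module and then
# injects each tool under its own name, so `send_to_server` resolves at call time.
implementation = '''
import json

def run(task_data):
    reply = send_to_server(json.dumps(task_data))
    return json.loads(reply)
'''

executor = UnsafeExecutor()
result = executor.run_routine('weather_forecast_v1', implementation,
                              {'location': 'New York', 'date': '2023-10-15'}, [send_tool])
print(result)  # {'temperature': 18, 'condition': 'Sunny', ...}
```

Note: as the TODO above says, `exec` on model-generated code is not sandboxed, so this should only be run in a containerized or otherwise isolated environment.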
flow.py ADDED
@@ -0,0 +1,183 @@
1
+ import dotenv
2
+ dotenv.load_dotenv()
3
+
4
+ import json
5
+ import os
6
+ import random
7
+ import threading
8
+ import time
9
+
10
+ from toolformers.base import Tool, parameter_from_openai_api, StringParameter
11
+ from toolformers.base import Toolformer
12
+ from toolformers.camel import make_openai_toolformer
13
+ from toolformers.langchain_agent import LangChainAnthropicToolformer
14
+ from toolformers.sambanova import SambanovaToolformer
15
+ from toolformers.gemini import GeminiToolformer
16
+
17
+ from querier import Querier
18
+ from responder import Responder
19
+ from negotiator import SenderNegotiator, ReceiverNegotiator
20
+ from programmer import SenderProgrammer, ReceiverProgrammer
21
+ from executor import UnsafeExecutor
22
+ from utils import compute_hash
23
+
24
+
25
+ def create_toolformer(model_name) -> Toolformer:
26
+ if model_name in ['gpt-4o', 'gpt-4o-mini']:
27
+ return make_openai_toolformer(model_name)
28
+ elif 'claude' in model_name:
29
+ return LangChainAnthropicToolformer(model_name, os.environ.get('ANTHROPIC_API_KEY'))
30
+ elif model_name in ['llama3-405b']:
31
+ return SambanovaToolformer(model_name)
32
+ elif model_name in ['gemini-1.5-pro']:
33
+ return GeminiToolformer(model_name)
34
+ else:
35
+ raise ValueError(f"Unknown model name: {model_name}")
36
+
37
+ def full_flow(schema, alice_model, bob_model):
38
+ NL_MESSAGES = []
39
+ NEGOTIATION_MESSAGES = []
40
+ STRUCTURED_MESSAGES = []
41
+ ARTIFACTS = {}
42
+
43
+ toolformer_alice = create_toolformer(alice_model)
44
+ toolformer_bob = create_toolformer(bob_model)
45
+
46
+ querier = Querier(toolformer_alice)
47
+ responder = Responder(toolformer_bob)
48
+
49
+ tools = []
50
+
51
+ for tool_schema in schema['tools']:
52
+ parameters = [parameter_from_openai_api(name, schema, name in tool_schema['input']['required']) for name, schema in tool_schema['input']['properties'].items()]
53
+
54
+ def tool_fn(*args, tool_schema=tool_schema, **kwargs):  # bind the current schema to avoid late-binding across loop iterations
55
+ print(f'Bob tool {tool_schema["name"]} called with args {args} and kwargs {kwargs}')
56
+ return random.choice(tool_schema['dummy_outputs'])
57
+
58
+ tool = Tool(tool_schema['name'], tool_schema['description'], parameters, tool_fn, tool_schema['output'])
59
+ tools.append(tool)
60
+
61
+ def nl_callback_fn(query):
62
+ print(query)
63
+ NL_MESSAGES.append({
64
+ 'role': 'assistant',
65
+ #'content': query['body'],
66
+ 'body': query['body'],
67
+ 'protocolHash': None
68
+ })
69
+
70
+ response = responder.reply_to_query(query['body'], query['protocolHash'], tools, '')
71
+
72
+ NL_MESSAGES.append({
73
+ 'role': 'user',
74
+ #'content': response['body']
75
+ 'status': 'success',
76
+ 'body': response['body']
77
+ })
78
+
79
+ return response
80
+
81
+ negotiator_sender = SenderNegotiator(toolformer_alice)
82
+ negotiator_receiver = ReceiverNegotiator(toolformer_bob, tools, '')
83
+
84
+ def negotiation_callback_fn(query):
85
+ print(query)
86
+ NEGOTIATION_MESSAGES.append({
87
+ 'role': 'assistant',
88
+ 'content': query
89
+ })
90
+
91
+ response = negotiator_receiver.handle_negotiation(query)
92
+
93
+ NEGOTIATION_MESSAGES.append({
94
+ 'role': 'user',
95
+ 'content': response
96
+ })
97
+
98
+ #print('CURRENT NEGOTIATION MESSAGES:', len(NEGOTIATION_MESSAGES))
99
+
100
+ return response
101
+
102
+ def final_message_callback_fn(query):
103
+ NEGOTIATION_MESSAGES.append({
104
+ 'role': 'assistant',
105
+ 'content': query
106
+ })
107
+
108
+ sender_programmer = SenderProgrammer(toolformer_alice)
109
+ receiver_programmer = ReceiverProgrammer(toolformer_bob)
110
+
111
+ executor = UnsafeExecutor()
112
+
113
+ def structured_callback_fn(query):
114
+ STRUCTURED_MESSAGES.append({
115
+ 'role': 'assistant',
116
+ #'content': query
117
+ 'body': json.dumps(query) if isinstance(query, dict) else query,
118
+ 'protocolHash': ARTIFACTS['protocol']['hash'],
119
+ 'protocolSources': ['https://...']
120
+ })
121
+
122
+ try:
123
+ response = executor.run_routine(ARTIFACTS['protocol']['hash'], ARTIFACTS['implementation_receiver'], query, tools)
124
+ except Exception as e:
125
+ STRUCTURED_MESSAGES.append({
126
+ 'role': 'user',
127
+ 'status': 'error',
128
+ 'message': str(e)
129
+ })
130
+ return 'Error'
131
+
132
+ STRUCTURED_MESSAGES.append({
133
+ 'role': 'user',
134
+ #'content': response
135
+ 'status': 'success',
136
+ 'body': json.dumps(response) if isinstance(response, dict) else response
137
+ })
138
+
139
+ return response
140
+
141
+ def flow():
142
+ task_data = random.choice(schema['examples'])
143
+ querier.send_query_without_protocol(schema, task_data, nl_callback_fn)
144
+
145
+ #time.sleep(1)
146
+
147
+ res = negotiator_sender.negotiate_protocol_for_task(schema, negotiation_callback_fn, final_message_callback_fn=final_message_callback_fn)
148
+ protocol_hash = compute_hash(res['protocol'])
149
+ res['hash'] = protocol_hash
150
+ ARTIFACTS['protocol'] = res
151
+
152
+ protocol_document = res['protocol']
153
+
154
+ implementation_sender = sender_programmer.write_routine_for_task(schema, protocol_document)
155
+
156
+ ARTIFACTS['implementation_sender'] = implementation_sender
157
+
158
+ implementation_receiver = receiver_programmer.write_routine_for_tools(tools, protocol_document, '')
159
+
160
+ ARTIFACTS['implementation_receiver'] = implementation_receiver
161
+ send_tool = Tool('send_to_server', 'Send to server', [StringParameter('query', 'The query', True)], structured_callback_fn)
162
+
163
+ try:
164
+ executor.run_routine(protocol_hash, implementation_sender, task_data, [send_tool])
165
+ except Exception as e:
166
+ STRUCTURED_MESSAGES.append({
167
+ 'role': 'assistant',
168
+ 'status': 'error',
169
+ 'message': str(e)
170
+ })
171
+
172
+ def get_info():
173
+ return NL_MESSAGES, NEGOTIATION_MESSAGES, STRUCTURED_MESSAGES, ARTIFACTS.get('protocol', {}).get('protocol', ''), \
174
+ ARTIFACTS.get('implementation_sender', ''), ARTIFACTS.get('implementation_receiver', '')
175
+
176
+ thread = threading.Thread(
177
+ target = lambda: flow()
178
+ )
179
+ thread.start()
180
+ while thread.is_alive():
181
+ yield get_info()
182
+ time.sleep(0.2)
183
+ yield get_info()
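For reference, a sketch of driving `full_flow` outside of Gradio (this mirrors the loop in the `respond` handler in `app.py`; it assumes the API keys for the chosen models are configured in `.env`):

```python
from app import schema        # the weather_forecast task schema defined in app.py
from flow import full_flow

# full_flow yields periodic snapshots while the background thread runs, plus one final snapshot.
for nl_raw, negotiation, structured_raw, protocol, alice_impl, bob_impl in full_flow(schema, 'gpt-4o', 'gpt-4o-mini'):
    pass

print(protocol)      # negotiated protocol document ('' until negotiation finishes)
print(alice_impl)    # generated sender-side routine
print(bob_impl)      # generated receiver-side routine
```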
negotiator.py ADDED
@@ -0,0 +1,116 @@
1
+ import json
2
+ import uuid
3
+
4
+ from toolformers.base import Toolformer
5
+ from utils import extract_substring
6
+
7
+ NEGOTIATION_RULES = '''
8
+ Here are some rules (that should also be explained to the other GPT):
9
+ - You can assume that the protocol has a sender and a receiver. Do not worry about how the messages will be delivered, focus only on the content of the messages.
10
+ - Keep the protocol short and simple. It should be easy to understand and implement.
11
+ - The protocol must specify the exact format of what is sent and received. Do not leave it open to interpretation.
12
+ - The implementation will be written by a programmer that does not have access to the negotiation process, so make sure the protocol is clear and unambiguous.
13
+ - The implementation will receive a string and return a string, so structure your protocol accordingly.
14
+ - The other party might have a different internal data schema or set of tools, so make sure that the protocol is flexible enough to accommodate that.
15
+ - There will only be one message sent by the sender and one message sent by the receiver. Design the protocol accordingly.
16
+ - Keep the negotiation short: no need to repeat the same things over and over.
17
+ - If the other party has proposed a protocol and you're good with it, there's no reason to keep negotiating or to repeat the protocol to the other party.
18
+ - Do not restate parts of the protocols that have already been agreed upon.
19
+ And remember: keep the protocol as simple and unequivocal as necessary. The programmer that will implement the protocol can code, but they are not a mind reader.
20
+ '''
21
+
22
+ TASK_NEGOTIATOR_PROMPT = f'''
23
+ You are ProtocolNegotiatorGPT. Your task is to negotiate a protocol that can be used to query a service.
24
+ You will receive a JSON schema of the task that the service must perform. Negotiate with the service to determine a protocol that can be used to query it.
25
+ To do so, you will chat with another GPT (role: user) that will negotiate on behalf of the service.
26
+ {NEGOTIATION_RULES}
27
+ Once you are ready to save the protocol, reply wrapping the final version of the protocol, as agreed in your negotiation, between the tags <FINALPROTOCOL> and </FINALPROTOCOL>.
28
+ Within the body of the tag, add the tags <NAME></NAME> and <DESCRIPTION></DESCRIPTION> to specify the name and description of the protocol.
29
+
30
+ Remember that the <FINALPROTOCOL></FINALPROTOCOL> tags should also contain the protocol itself. Nothing outside such tags will be stored.
31
+ '''
32
+
33
+ class SenderNegotiator:
34
+ def __init__(self, toolformer : Toolformer):
35
+ self.toolformer = toolformer
36
+
37
+ def negotiate_protocol_for_task(self, task_schema, callback_fn, final_message_callback_fn=None):
38
+ found_protocol = None
39
+
40
+ prompt = TASK_NEGOTIATOR_PROMPT + '\nThe JSON schema of the task is the following:\n\n' + json.dumps(task_schema, indent=2)
41
+
42
+ conversation = self.toolformer.new_conversation(prompt, [], category='negotiation')
43
+
44
+ other_message = 'Hello! How may I help you?'
45
+ conversation_id = None
46
+
47
+ for i in range(10):
48
+ print('===NegotiatorGPT===')
49
+ message = conversation.chat(other_message, print_output=True)
50
+
51
+ print('Checking if we can extract from:', message)
52
+ print('---------')
53
+ protocol = extract_substring(message, '<FINALPROTOCOL>', '</FINALPROTOCOL>')
54
+
55
+ if protocol is None:
56
+ print('Could not extract')
57
+ other_message = callback_fn(message)
58
+ print()
59
+ print('===Other GPT===')
60
+ print(other_message)
61
+ print()
62
+ else:
63
+ if final_message_callback_fn:
64
+ rest_of_message = message.split('<FINALPROTOCOL>')[0]
65
+ final_message_callback_fn(rest_of_message)
66
+
67
+ name = extract_substring(protocol, '<NAME>', '</NAME>')
68
+ description = extract_substring(protocol, '<DESCRIPTION>', '</DESCRIPTION>')
69
+
70
+ if name is None:
71
+ name = 'Unnamed protocol'
72
+ if description is None:
73
+ description = 'No description provided'
74
+
75
+ found_protocol = {
76
+ 'name': name,
77
+ 'description': description,
78
+ 'protocol': protocol
79
+ }
80
+ break
81
+
82
+ return found_protocol
83
+
84
+ TOOLS_NEGOTIATOR_PROMPT = f'''
85
+ You are ProtocolNegotiatorGPT. You are negotiating a protocol on behalf of a web service that can perform a task.
86
+ The other party is a GPT that is negotiating on behalf of the user. Your goal is to negotiate a protocol that is simple and clear, \
87
+ but also expressive enough to allow the service to perform the task. A protocol is sufficiently expressive if you could write code \
88
+ that, given the query formatted according to the protocol and the tools at the service's disposal, can parse the query according to \
89
+ the protocol's specification, perform the task (if any) and send a reply.
90
+ {NEGOTIATION_RULES}
91
+ You will receive a list of tools that are available to the programmer that will implement the protocol.
92
+ When you are okay with the protocol, don't further repeat everything, just tell the other party that you are done.
93
+ '''
94
+
95
+ class ReceiverNegotiator:
96
+ def __init__(self, toolformer : Toolformer, tools, additional_info):
97
+ prompt = TOOLS_NEGOTIATOR_PROMPT
98
+
99
+ prompt += '\n\n' + additional_info
100
+
101
+ prompt += '\n\nThe tools that the implementer will have access to are:\n\n'
102
+
103
+ if len(tools) == 0:
104
+ prompt += 'No additional tools provided'
105
+ else:
106
+ for tool in tools:
107
+ prompt += tool.as_documented_python() + '\n\n'
108
+
109
+ print('Prompt:', prompt)
110
+
111
+ self.conversation = toolformer.new_conversation(prompt, tools, category='negotiation')
112
+
113
+ def handle_negotiation(self, message):
114
+ reply = self.conversation.chat(message, print_output=True)
115
+
116
+ return reply
programmer.py ADDED
@@ -0,0 +1,125 @@
1
+ # The programmer creates implementations depending on a protocol specification
2
+
3
+ import json
4
+ import os
5
+
6
+ from toolformers.base import Toolformer
7
+ from utils import extract_substring
8
+
9
+ TASK_PROGRAMMER_PROMPT = '''
10
+ You are ProtocolProgrammerGPT. You will act as an intermediate between a machine (that has a certain input and output schema in JSON) \
11
+ and a remote server that can perform a task following a certain protocol. Your task is to write a routine that takes some task data \
12
+ (which follows the input schema), sends query in a format defined by the protocol, parses it and returns the output according to the output schema so that \
13
+ the machine can use it.
14
+ The routine is a Python file that contains a function "send_query". send_query takes a single argument, "task_data", which is a dictionary, and must return \
15
+ a dictionary, which is the response to the query formatted according to the output schema.
16
+ In order to communicate with the remote server, you can use the function "send_to_server" that is already available in the environment.
17
+ send_to_server takes a single argument, "query" (which is a string formatted according to the protocol), and returns a string (again formatted according \
18
+ to the protocol). Do not worry about managing communication, everything is already set up for you. Just focus on preparing the right query.
19
+
20
+ Rules:
21
+ - The implementation must be written in Python.
22
+ - You can define any number of helper functions and import any libraries that are part of the Python standard library.
23
+ - Do not import libraries that are not part of the Python standard library.
24
+ - send_to_server will be already available in the environment. There is no need to import it.
25
+ - Your task is to prepare the query, send it and parse the response.
26
+ - Remember to import standard libraries if you need them.
27
+ - If there is an unexpected error that is not covered by the protocol, throw an exception.\
28
+ If instead the protocol specifies how to handle the error, return the response according to the protocol's specification.
29
+ - Do not execute anything (aside from library imports) when the file itself is loaded. I will personally import the file and call the send_query function with the task data.
30
+ Begin by thinking about the implementation and how you would structure the code. \
31
+ Then, write your implementation by writing a code block that contains the tags <IMPLEMENTATION> and </IMPLEMENTATION>. For example:
32
+ ```python
33
+ <IMPLEMENTATION>
34
+
35
+ def send_query(task_data):
36
+ ...
37
+
38
+ </IMPLEMENTATION>
39
+ '''
40
+
41
+ class SenderProgrammer:
42
+ def __init__(self, toolformer : Toolformer):
43
+ self.toolformer = toolformer
44
+
45
+ def write_routine_for_task(self, task_schema, protocol_document):
46
+ conversation = self.toolformer.new_conversation(TASK_PROGRAMMER_PROMPT, [], category='programming')
47
+ message = 'JSON schema:\n\n' + json.dumps(task_schema) + '\n\n' + 'Protocol document:\n\n' + protocol_document
48
+
49
+ for i in range(5):
50
+ reply = conversation.chat(message, print_output=True)
51
+
52
+ implementation = extract_substring(reply, '<IMPLEMENTATION>', '</IMPLEMENTATION>')
53
+
54
+ if implementation is not None:
55
+ break
56
+
57
+ message = 'You have not provided an implementation yet. Please provide one by surrounding it in the tags <IMPLEMENTATION> and </IMPLEMENTATION>.'
58
+
59
+ implementation = implementation.strip()
60
+
61
+ # Sometimes the LLM leaves the Markdown formatting in the implementation
62
+ implementation = implementation.replace('```python', '').replace('```', '').strip()
63
+
64
+ implementation = implementation.replace('def send_query(', 'def run(')
65
+
66
+ return implementation
67
+
68
+ TOOL_PROGRAMMER_PROMPT = '''
69
+ You are ProtocolProgrammerGPT. Your task is to write a routine that takes a query formatted according to the protocol and returns a response.
70
+ The routine is a Python file that contains a function "reply". reply takes a single argument, "query", which is a string, and must return a string.
71
+ Depending on the protocol, the routine might be need to perform some actions before returning the response. The user might provide you with a list of \
72
+ Python functions you can call to help you with this task. You don't need to worry about importing them, they are already available in the environment.
73
+ Rules:
74
+ - The implementation must be written in Python.
75
+ - You can define any number of helper functions and import any libraries that are part of the Python standard library.
76
+ - Do not import libraries that are not part of the Python standard library.
77
+ - Remember to import standard libraries if you need them.
78
+ - If there is an unexpected error that is not covered by the protocol, throw an exception.\
79
+ If instead the protocol specifies how to handle the error, return the response according to the protocol's specification.
80
+ - Do not execute anything (aside from library imports) when the file itself is loaded. I will personally import the file and call the reply function with the task data.
81
+ Begin by thinking about the implementation and how you would structure the code. \
82
+ Then, write your implementation by writing a code block that contains the tags <IMPLEMENTATION> and </IMPLEMENTATION>. For example:
83
+ ```python
84
+ <IMPLEMENTATION>
85
+
86
+ def reply(query):
87
+ ...
88
+
89
+ </IMPLEMENTATION>
90
+ '''
91
+
92
+ class ReceiverProgrammer:
93
+ def __init__(self, toolformer : Toolformer):
94
+ self.toolformer = toolformer
95
+
96
+ def write_routine_for_tools(self, tools, protocol_document, additional_info):
97
+ conversation = self.toolformer.new_conversation(TOOL_PROGRAMMER_PROMPT + additional_info, [], category='programming')
98
+
99
+ message = 'Protocol document:\n\n' + protocol_document + '\n\n' + 'Additional functions:\n\n'
100
+
101
+ if len(tools) == 0:
102
+ message += 'No additional functions provided'
103
+ else:
104
+ for tool in tools:
105
+ message += tool.as_documented_python() + '\n\n'
106
+
107
+
108
+ for i in range(5):
109
+ reply = conversation.chat(message, print_output=True)
110
+
111
+ implementation = extract_substring(reply, '<IMPLEMENTATION>', '</IMPLEMENTATION>')
112
+
113
+ if implementation is not None:
114
+ break
115
+
116
+ message = 'You have not provided an implementation yet. Please provide one by surrounding it in the tags <IMPLEMENTATION> and </IMPLEMENTATION>.'
117
+
118
+ implementation = implementation.strip()
119
+
120
+ # Sometimes the LLM leaves the Markdown formatting in the implementation
121
+ implementation = implementation.replace('```python', '').replace('```', '').strip()
122
+
123
+ implementation = implementation.replace('def reply(', 'def run(')
124
+
125
+ return implementation
querier.py ADDED
@@ -0,0 +1,158 @@
1
+ # The querier queries a service based on a protocol document.
2
+ # It receives the protocol document and writes the query that must be performed to the system.
3
+
4
+ import json
5
+ import os
6
+ from pathlib import Path
7
+
8
+ import sys
9
+
10
+ from toolformers.base import Tool, StringParameter, parameter_from_openai_api
11
+
12
+ #from utils import send_raw_query
13
+
14
+ PROTOCOL_QUERIER_PROMPT = 'You are QuerierGPT. You will receive a protocol document detailing how to query a service. Reply with a structured query which can be sent to the service. ' \
15
+ 'Only reply with the query itself, with no additional information or escaping. Similarly, do not add any additional whitespace or formatting.'
16
+
17
+ def construct_query_description(protocol_document, task_schema, task_data):
18
+ query_description = ''
19
+ if protocol_document is not None:
20
+ query_description += 'Protocol document:\n\n'
21
+ query_description += protocol_document + '\n\n'
22
+ query_description += 'JSON schema of the task:\n\n'
23
+ query_description += 'Input (i.e. what the machine will provide you):\n'
24
+ query_description += json.dumps(task_schema['input'], indent=2) + '\n\n'
25
+ query_description += 'Output (i.e. what you have to provide to the machine):\n'
26
+ query_description += json.dumps(task_schema['output'], indent=2) + '\n\n'
27
+ query_description += 'JSON data of the task:\n\n'
28
+ query_description += json.dumps(task_data, indent=2) + '\n\n'
29
+
30
+ return query_description
31
+
32
+ NL_QUERIER_PROMPT = 'You are NaturalLanguageQuerierGPT. You act as an intermediary between a machine (which has a very specific input and output schema) and an agent (who uses natural language). ' \
33
+ 'You will receive a task description (including a schema of the input and output) that the machine uses and the corresponding data. Call the \"sendQuery\" tool with a natural language message where you ask to perform the task according to the data. ' \
34
+ 'Make sure to mention all the relevant information. ' \
35
+ 'Do not worry about managing communication, everything is already set up for you. Just focus on asking the right question. ' \
36
+ 'The sendQuery tool will return the reply of the service.\n' \
37
+ 'Once you receive the reply, call the \"deliverStructuredOutput\" tool with parameters according to the task\'s output schema. \n' \
38
+ 'Note: you cannot call sendQuery multiple times, so make sure to ask the right question the first time. Similarly, you cannot call deliverStructuredOutput multiple times, so make sure to deliver the right output the first time. ' \
39
+ 'If the query fails, do not attempt to send another query.'
40
+
41
+ def parse_and_handle_query(query, callback_fn, protocol_id, source):
42
+ if isinstance(query, dict):
43
+ query = json.dumps(query)
44
+ response = callback_fn({
45
+ "protocolHash": protocol_id,
46
+ "body": query,
47
+ "protocolSources": None if source is None else [source]
48
+ })
49
+
50
+ print('Response:', response, type(response))
51
+
52
+ if response['status'] == 'success':
53
+ return response['body']
54
+ else:
55
+ return 'Error calling the tool: ' + response['message']
56
+
57
+ def get_output_parameters(task_schema):
58
+ output_schema = task_schema['output']
59
+ required_parameters = output_schema['required']
60
+
61
+ parameters = []
62
+
63
+ for parameter_name, parameter_schema in output_schema['properties'].items():
64
+ parameter = parameter_from_openai_api(parameter_name, parameter_schema, parameter_name in required_parameters)
65
+ parameters.append(parameter)
66
+
67
+ return parameters
68
+
69
+ class Querier:
70
+ def __init__(self, toolformer):
71
+ self.toolformer = toolformer
72
+
73
+ def handle_conversation(self, prompt, message, callback_fn, protocol_id, source, output_parameters):
74
+ sent_query_counter = 0
75
+
76
+ def send_query_internal(query):
77
+ print('Sending query:', query)
78
+ nonlocal sent_query_counter
79
+ sent_query_counter += 1
80
+
81
+ if sent_query_counter > 50:
82
+ # All hope is lost, crash
83
+ sys.exit(-2)
84
+ elif sent_query_counter > 10:
85
+ # LLM is not listening, throw an exception
86
+ raise Exception('Too many attempts to send queries. Exiting.')
87
+ elif sent_query_counter > 5:
88
+ # LLM is not listening, issue a warning
89
+ return 'You have attempted to send too many queries. Finish the message and allow the user to speak, or the system will crash.'
90
+ elif sent_query_counter > 1:
91
+ return 'You have already sent a query. You cannot send another one.'
92
+ return parse_and_handle_query(query, callback_fn, protocol_id, source)
93
+
94
+ send_query_tool = Tool('sendQuery', 'Send a query to the other service based on a protocol document.', [
95
+ StringParameter('query', 'The query to send to the service', True)
96
+ ], send_query_internal)
97
+
98
+ found_output = None
99
+ registered_output_counter = 0
100
+
101
+ def register_output(**kwargs):
102
+ print('Registering output:', kwargs)
103
+
104
+ nonlocal found_output
105
+ nonlocal registered_output_counter
106
+ if found_output is not None:
107
+ registered_output_counter += 1
108
+
109
+ if registered_output_counter > 50:
110
+ # All hope is lost, crash
111
+ sys.exit(-2)
112
+ elif registered_output_counter > 10:
113
+ # LLM is not listening, raise an exception
114
+ raise Exception('Too many attempts to register outputs. Exiting.')
115
+ elif registered_output_counter > 5:
116
+ # LLM is not listening, issue a warning
117
+ return 'You have attempted to register too many outputs. Finish the message and allow the user to speak, or the system will crash.'
118
+ elif registered_output_counter > 0:
119
+ return 'You have already registered an output. You cannot register another one.'
120
+
121
+ output = json.dumps(kwargs)
122
+
123
+ found_output = output
124
+ return 'Done'
125
+
126
+ register_output_tool = Tool('deliverStructuredOutput', 'Deliver the structured output to the machine.',
127
+ output_parameters
128
+ , register_output)
129
+
130
+ conversation = self.toolformer.new_conversation(prompt, [send_query_tool, register_output_tool], category='conversation')
131
+
132
+ for i in range(5):
133
+ conversation.chat(message, print_output=True)
134
+
135
+ if found_output is not None:
136
+ break
137
+
138
+ # If we haven't sent a query yet, we can't proceed
139
+ if sent_query_counter == 0:
140
+ message = 'You must send a query before delivering the structured output.'
141
+ elif found_output is None:
142
+ message = 'You must deliver the structured output.'
143
+
144
+ return found_output
145
+
146
+ #def send_query_with_protocol(self, storage, task_schema, task_data, target_node, protocol_id, source):
147
+ # base_folder = Path(os.environ.get('STORAGE_PATH')) / 'protocol_documents'
148
+ # protocol_document = storage.load_protocol_document(base_folder, protocol_id)
149
+ # query_description = construct_query_description(protocol_document, task_schema, task_data)
150
+ # output_parameters = get_output_parameters(task_schema)
151
+ #
152
+ # return self.handle_conversation(PROTOCOL_QUERIER_PROMPT, query_description, target_node, protocol_id, source, output_parameters)
153
+
154
+ def send_query_without_protocol(self, task_schema, task_data, callback_fn):
155
+ query_description = construct_query_description(None, task_schema, task_data)
156
+ output_parameters = get_output_parameters(task_schema)
157
+
158
+ return self.handle_conversation(NL_QUERIER_PROMPT, query_description, callback_fn, None, None, output_parameters)
requirements.txt ADDED
@@ -0,0 +1,28 @@
1
+ asknews==0.7.51
2
+ camel-ai==0.2.6
3
+ google-ai-generativelanguage==0.6.4
4
+ google-api-core==2.22.0
5
+ google-api-python-client==2.151.0
6
+ google-auth==2.36.0
7
+ google-auth-httplib2==0.2.0
8
+ google-cloud-core==2.4.1
9
+ google-cloud-storage==2.18.2
10
+ google-cloud-vision==3.8.0
11
+ google-crc32c==1.6.0
12
+ google-generativeai==0.6.0
13
+ gradio==5.5.0
14
+ gradio_client==1.4.2
15
+ huggingface-hub==0.26.2
16
+ langchain==0.3.7
17
+ langchain-anthropic==0.2.4
18
+ langchain-community==0.3.5
19
+ langchain-core==0.3.15
20
+ langchain-text-splitters==0.3.2
21
+ langgraph==0.2.45
22
+ langgraph-checkpoint==2.0.2
23
+ langgraph-sdk==0.1.35
24
+ python-dotenv==1.0.1
25
+ requests-oauthlib==1.3.1
26
+ sseclient-py==1.8.0
27
+ torch==2.5.1
28
+ transformers==4.46.2
responder.py ADDED
@@ -0,0 +1,87 @@
1
+ # The responder is a special toolformer that replies to a service based on a protocol document.
2
+ # It receives the protocol document and writes the response that must be sent to the system.
3
+
4
+ import json
5
+ import os
6
+ from pathlib import Path
7
+
8
+ from toolformers.base import Toolformer
9
+
10
+ # TODO: A tool to declare an error?
11
+
12
+
13
+ PROTOCOL_RESPONDER_PROMPT = 'You are ResponderGPT. You will receive a protocol document detailing how to respond to a query. '\
14
+ 'Use the provided functions to execute what is requested and provide the response according to the protocol\'s specification. ' \
15
+ 'Only reply with the response itself, with no additional information or escaping. Similarly, do not add any additional whitespace or formatting.'# \
16
+ # 'If you do not have enough information to reply, or if you cannot execute the request, reply with "ERROR" (without quotes).'
17
+
18
+ NL_RESPONDER_PROMPT = 'You are NaturalLanguageResponderGPT. You will receive a query from a user. ' \
19
+ 'Use the provided functions to execute what is requested and reply with a response (in natural language). ' \
20
+ 'Important: the user does not have the capacity to respond to follow-up questions, so if you think you have enough information to reply/execute the actions, do so.'
21
+ #'If you do not have enough information to reply, if you cannot execute the request, or if the request is invalid, reply with "ERROR" (without quotes).' \
22
+
23
+
24
+
25
+ class Responder:
26
+ def __init__(self, toolformer : Toolformer):
27
+ self.toolformer = toolformer
28
+
29
+ def reply_with_protocol_document(self, query, protocol_document, tools, additional_info):
30
+ print('===NL RESPONDER (WITH PROTOCOL)===')
31
+
32
+ conversation = self.toolformer.new_conversation(PROTOCOL_RESPONDER_PROMPT + additional_info, tools, category='conversation')
33
+
34
+ prompt = 'The protocol is the following:\n\n' + protocol_document + '\n\nThe query is the following:' + query
35
+
36
+ reply = conversation.chat(prompt, print_output=True)
37
+
38
+ print('======')
39
+
40
+ if 'error' in reply.lower().strip()[-10:]:
41
+ return {
42
+ 'status': 'error',
43
+ 'message': 'Error in the response'
44
+ }
45
+
46
+ return {
47
+ 'status': 'success',
48
+ 'body': reply
49
+ }
50
+
51
+
52
+ def reply_to_nl_query(self, query, tools, additional_info):
53
+ print('===NL RESPONDER (NO PROTOCOL)===')
54
+ print(NL_RESPONDER_PROMPT + additional_info)
55
+ print([tool.name for tool in tools])
56
+
57
+ conversation = self.toolformer.new_conversation(NL_RESPONDER_PROMPT + additional_info, tools, category='conversation')
58
+
59
+ print('Created conversation')
60
+ try:
61
+ reply = conversation.chat(query, print_output=True)
62
+ except Exception as e:
63
+ # Print traceback
64
+ import traceback
65
+ traceback.print_exc()
66
+ raise e
67
+ print('======')
68
+
69
+ if 'error' in reply.lower().strip()[-10:]:
70
+ return {
71
+ 'status': 'error',
72
+ }
73
+
74
+ return {
75
+ 'status': 'success',
76
+ 'body': reply
77
+ }
78
+
79
+
80
+ def reply_to_query(self, query, protocol_id, tools, additional_info):
81
+ print('Additional info:', additional_info)
82
+ if protocol_id is None:
83
+ return self.reply_to_nl_query(query, tools, additional_info)
84
+ #else:
85
+ # base_folder = Path(os.environ.get('STORAGE_PATH')) / 'protocol_documents'
86
+ # protocol_document = self.memory.load_protocol_document(base_folder, protocol_id)
87
+ # return self.reply_with_protocol_document(query, protocol_document, tools, additional_info)
toolformers/base.py ADDED
@@ -0,0 +1,332 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import json
4
+
5
+ from google.generativeai.types import CallableFunctionDeclaration
6
+ import google.generativeai.types.content_types as content_types
7
+
8
+ from utils import add_params_and_annotations
9
+
10
+ class Parameter:
11
+ def __init__(self, name, description, required):
12
+ self.name = name
13
+ self.description = description
14
+ self.required = required
15
+
16
+ def as_openai_info(self):
17
+ pass
18
+
19
+ def as_standard_api(self):
20
+ pass
21
+
22
+ class StringParameter(Parameter):
23
+ def __init__(self, name, description, required):
24
+ super().__init__(name, description, required)
25
+
26
+ def as_openai_info(self):
27
+ return {
28
+ "type": "string",
29
+ "name": self.name,
30
+ "description": self.description
31
+ }
32
+
33
+ def as_standard_api(self):
34
+ return {
35
+ "type": "string",
36
+ "name": self.name,
37
+ "description": self.description,
38
+ "required": self.required
39
+ }
40
+
41
+ def as_natural_language(self):
42
+ return f'{self.name} (string{", required" if self.required else ""}): {self.description}.'
43
+
44
+ def as_documented_python(self):
45
+ return f'{self.name} (str{", required" if self.required else ""}): {self.description}.'
46
+
47
+ def as_gemini_tool(self):
48
+ return {
49
+ 'type': 'string',
50
+ 'description': self.description
51
+ }
52
+
53
+ @staticmethod
54
+ def from_standard_api(api_info):
55
+ return StringParameter(api_info["name"], api_info["description"], api_info["required"])
56
+
57
+ class EnumParameter(Parameter):
58
+ def __init__(self, name, description, values, required):
59
+ super().__init__(name, description, required)
60
+ self.values = values
61
+
62
+ def as_openai_info(self):
63
+ return {
64
+ "type": "string",
65
+ "description": self.description,
66
+ "values": self.values
67
+ }
68
+
69
+ def as_standard_api(self):
70
+ return {
71
+ "type": "enum",
72
+ "name": self.name,
73
+ "description": self.description,
74
+ "values": self.values,
75
+ "required": self.required
76
+ }
77
+
78
+ def as_natural_language(self):
79
+ return f'{self.name} (enum{", required" if self.required else ""}): {self.description}. Possible values: {", ".join(self.values)}'
80
+
81
+ def as_documented_python(self):
82
+ return f'{self.name} (str{", required" if self.required else ""}): {self.description}. Possible values: {", ".join(self.values)}'
83
+
84
+ def as_gemini_tool(self):
85
+ return {
86
+ 'description': self.description,
87
+ 'type': 'string',
88
+ 'enum': self.values
89
+ }
90
+
91
+ @staticmethod
92
+ def from_standard_api(api_info):
93
+ return EnumParameter(api_info["name"], api_info["description"], api_info["values"], api_info["required"])
94
+
95
+ class NumberParameter(Parameter):
96
+ def __init__(self, name, description, required):
97
+ super().__init__(name, description, required)
98
+
99
+ def as_openai_info(self):
100
+ return {
101
+ "type": "number",
102
+ "description": self.description
103
+ }
104
+
105
+ def as_standard_api(self):
106
+ return {
107
+ "type": "number",
108
+ "name": self.name,
109
+ "description": self.description,
110
+ "required": self.required
111
+ }
112
+
113
+ def as_natural_language(self):
114
+ return f'{self.name} (number): {self.description}'
115
+
116
+ def as_documented_python(self):
117
+ return f'{self.name} (number): {self.description}'
118
+
119
+ def as_gemini_tool(self):
120
+ return {
121
+ 'description': self.description,
122
+ 'type': 'number'
123
+ }
124
+
125
+ class ArrayParameter(Parameter):
126
+ def __init__(self, name, description, required, item_schema):
127
+ super().__init__(name, description, required)
128
+ self.item_schema = item_schema
129
+
130
+ def as_openai_info(self):
131
+ return {
132
+ "type": "array",
133
+ "description": self.description,
134
+ "items": self.item_schema
135
+ }
136
+
137
+ def as_standard_api(self):
138
+ return {
139
+ "type": "array",
140
+ "name": self.name,
141
+ "description": self.description,
142
+ "required": self.required,
143
+ "item_schema": self.item_schema
144
+ }
145
+
146
+ def as_natural_language(self):
147
+ return f'{self.name} (array): {self.description}. Each item should follow the JSON schema: {json.dumps(self.item_schema)}'
148
+
149
+ def as_documented_python(self):
150
+ return f'{self.name} (list): {self.description}. Each item should follow the JSON schema: {json.dumps(self.item_schema)}'
151
+
152
+ def as_gemini_tool(self):
153
+ return {
154
+ 'description': self.description,
155
+ 'type': 'array',
156
+ 'items': self.item_schema
157
+ }
158
+
159
+ def parameter_from_openai_api(parameter_name, schema, required):
160
+ if 'enum' in schema:
161
+ return EnumParameter(parameter_name, schema['description'], schema['enum'], required)
162
+ elif schema['type'] == 'string':
163
+ return StringParameter(parameter_name, schema['description'], required)
164
+ elif schema['type'] == 'number':
165
+ return NumberParameter(parameter_name, schema['description'], required)
166
+ elif schema['type'] == 'array':
167
+ return ArrayParameter(parameter_name, schema['description'], required, schema['items'])
168
+ else:
169
+ raise ValueError(f'Unknown parameter type: {schema["type"]}')
170
+
171
+ class Tool:
172
+ def __init__(self, name, description, parameters, function, output_schema=None):
173
+ self.name = name
174
+ self.description = description
175
+ self.parameters = parameters
176
+ self.function = function
177
+ self.output_schema = output_schema
178
+
179
+ def call_tool_for_toolformer(self, *args, **kwargs):
180
+ print(f'Toolformer called tool {self.name} with args {args} and kwargs {kwargs}')
181
+ # Unlike a call from a routine, this call catches exceptions and returns them as strings
182
+ try:
183
+ tool_reply = self.function(*args, **kwargs)
184
+ print(f'Tool {self.name} returned: {tool_reply}')
185
+ return tool_reply
186
+ except Exception as e:
187
+ print(f'Tool {self.name} failed with exception: {e}')
188
+ return 'Tool call failed: ' + str(e)
189
+
190
+ def as_openai_info(self):
191
+ return {
192
+ "type": "function",
193
+ "function": {
194
+ "name": self.name,
195
+ "description": self.description,
196
+ "parameters": {
197
+ "type" : "object",
198
+ "properties": {parameter.name : parameter.as_openai_info() for parameter in self.parameters},
199
+ "required": [parameter.name for parameter in self.parameters if parameter.required]
200
+ }
201
+ }
202
+ }
203
+
204
+ def as_gemini_tool(self) -> CallableFunctionDeclaration:
205
+ if len(self.parameters) == 0:
206
+ parameters = None
207
+ else:
208
+ parameters = {
209
+ 'type': 'object',
210
+ 'properties': {parameter.name: parameter.as_gemini_tool() for parameter in self.parameters},
211
+ 'required': [parameter.name for parameter in self.parameters if parameter.required]
212
+ }
213
+ return content_types.Tool([CallableFunctionDeclaration(
214
+ name=self.name,
215
+ description=self.description,
216
+ parameters=parameters,
217
+ function=self.call_tool_for_toolformer
218
+ )])
219
+
220
+ def as_llama_schema(self):
221
+ schema = {
222
+ 'name': self.name,
223
+ 'description': self.description,
224
+ 'parameters': {parameter.name : parameter.as_openai_info() for parameter in self.parameters},
225
+ 'required': [parameter.name for parameter in self.parameters if parameter.required]
226
+ }
227
+
228
+ if self.output_schema is not None:
229
+ schema['output_schema'] = self.output_schema
230
+
231
+ return schema
232
+
233
+ def as_natural_language(self):
234
+ print('Converting to natural language')
235
+ print('Number of parameters:', len(self.parameters))
236
+ nl = f'Function {self.name}: {self.description}. Parameters:\n'
237
+
238
+ if len(self.parameters) == 0:
239
+ nl += 'No parameters.'
240
+ else:
241
+ for parameter in self.parameters:
242
+ nl += '\t' + parameter.as_natural_language() + '\n'
243
+
244
+ if self.output_schema is not None:
245
+ nl += f'\nReturns a dictionary with schema: {json.dumps(self.output_schema, indent=2)}'
246
+
247
+ return nl
248
+
249
+ def as_standard_api(self):
250
+ return {
251
+ "name": self.name,
252
+ "description": self.description,
253
+ "parameters": [parameter.as_standard_api() for parameter in self.parameters]
254
+ }
255
+
256
+ def as_documented_python(self):
257
+ documented_python = f'Tool {self.name}:\n\n{self.description}\nParameters:\n'
258
+
259
+ if len(self.parameters) == 0:
260
+ documented_python += 'No parameters.'
261
+ else:
262
+ for parameter in self.parameters:
263
+ documented_python += '\t' + parameter.as_documented_python() + '\n'
264
+
265
+ if self.output_schema is not None:
266
+ documented_python += f'\nReturns a dictionary with schema: {json.dumps(self.output_schema, indent=2)}'
267
+
268
+ return documented_python
269
+
270
+ def as_executable_function(self):
271
+ # Create an actual function that can be called
272
+ def f(*args, **kwargs):
273
+ print('Routine called tool', self.name, 'with args', args, 'and kwargs', kwargs)
274
+ response = self.function(*args, **kwargs)
275
+ print('Tool', self.name, 'returned:', response)
276
+ return response
277
+
278
+ return f
279
+
280
+ def as_annotated_function(self):
281
+ def wrapped_fn(*args, **kwargs):
282
+ return self.call_tool_for_toolformer(*args, **kwargs)
283
+
284
+ parsed_parameters = {}
285
+
286
+ description = self.description
287
+
288
+ for parameter_name, parameter_schema in self.as_openai_info()['function']['parameters']['properties'].items():
289
+ if parameter_schema['type'] == 'string':
290
+ parsed_parameters[parameter_name] = (str, parameter_schema['description'])
291
+ elif parameter_schema['type'] == 'number':
292
+ parsed_parameters[parameter_name] = (float, parameter_schema['description'])
293
+ elif parameter_schema['type'] == 'object':
294
+ parsed_parameters[parameter_name] = (dict, parameter_schema['description'])
295
+
296
+ description += f'\n{parameter_name} has the schema:\n' + json.dumps(parameter_schema) + '\n'
297
+ else:
298
+ raise ValueError(f'Unknown parameter type: {parameter_schema["type"]}')
299
+
300
+ return_type = type(None)
301
+
302
+ if self.output_schema is not None:
303
+ #description += '\nOutput schema:\n' + json.dumps(self.output_schema)
304
+
305
+ if self.output_schema['type'] == 'string':
306
+ return_type = str
307
+ elif self.output_schema['type'] == 'number':
308
+ return_type = float
309
+ elif self.output_schema['type'] == 'object':
310
+ return_type = dict
311
+ else:
312
+ raise ValueError(f'Unknown output type: {self.output_schema["type"]}')
313
+
314
+ return add_params_and_annotations(
315
+ self.name, description, parsed_parameters, return_type)(wrapped_fn)
316
+
317
+ @staticmethod
318
+ def from_openai_info(info, func):
319
+ parameters = [parameter_from_openai_api(name, schema, name in info['function']['parameters']['required']) for name, schema in info['function']['parameters']['properties'].items()]
320
+ return Tool(info['function']['name'], info['function']['description'], parameters, func)
321
+
322
+
323
+ class Conversation(ABC):
324
+ @abstractmethod
325
+ def chat(self, message, role='user', print_output=True):
326
+ pass
327
+
328
+ class Toolformer(ABC):
329
+ @abstractmethod
330
+ def new_conversation(self, prompt, tools, category=None) -> Conversation:
331
+ pass
332
+
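As a rough illustration of the `Tool`/`Parameter` abstractions above, a sketch of declaring a tool and rendering it for the different backends (the tool itself and its values are hypothetical):

```python
from toolformers.base import Tool, StringParameter, NumberParameter

def get_forecast(location, days=1):
    # Hypothetical dummy implementation.
    return {'temperature': 18, 'condition': 'Sunny'}

tool = Tool(
    'get_forecast',
    'Return a dummy weather forecast.',
    [StringParameter('location', 'Name of the location', True),
     NumberParameter('days', 'How many days ahead', False)],
    get_forecast,
    output_schema={'type': 'object'}
)

print(tool.as_openai_info())        # OpenAI function-calling schema
print(tool.as_documented_python())  # plain-text documentation used in the negotiation prompts
forecast = tool.as_executable_function()(location='Paris', days=2)
```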
toolformers/camel.py ADDED
@@ -0,0 +1,86 @@
1
+ import datetime
2
+ import os
3
+ from typing import List
4
+ import warnings
5
+
6
+ from toolformers.base import Conversation, Toolformer, Tool
7
+ from camel.messages import BaseMessage
8
+ from camel.models import ModelFactory
9
+ from camel.types import ModelPlatformType, ModelType
10
+ from camel.messages import BaseMessage as bm
11
+ from camel.agents import ChatAgent
12
+ from camel.toolkits.function_tool import FunctionTool
13
+ from camel.configs.openai_config import ChatGPTConfig
14
+
15
+ class CamelConversation(Conversation):
16
+ def __init__(self, toolformer, agent, category=None):
17
+ self.toolformer = toolformer
18
+ self.agent = agent
19
+ self.category = category
20
+
21
+ def chat(self, message, role='user', print_output=True):
22
+ agent_id = os.environ.get('AGENT_ID', None)
23
+
24
+ start_time = datetime.datetime.now()
25
+
26
+ if role == 'user':
27
+ formatted_message = BaseMessage.make_user_message('user', message)
28
+ elif role == 'assistant':
29
+ formatted_message = BaseMessage.make_assistant_message('assistant', message)
30
+ else:
31
+ raise ValueError('Role must be either "user" or "assistant".')
32
+
33
+ response = self.agent.step(formatted_message)
34
+
35
+ reply = response.msg.content
36
+
37
+ if print_output:
38
+ print(reply)
39
+
40
+ return reply
41
+
42
+ class CamelToolformer(Toolformer):
43
+ def __init__(self, model_platform, model_type, model_config_dict, name=None):
44
+ self.model_platform = model_platform
45
+ self.model_type = model_type
46
+ self.model_config_dict = model_config_dict
47
+ self._name = name
48
+
49
+ @property
50
+ def name(self):
51
+ if self._name is None:
52
+ return f'{self.model_platform.value}_{self.model_type.value}'
53
+ else:
54
+ return self._name
55
+
56
+ def new_conversation(self, prompt, tools : List[Tool], category=None) -> Conversation:
57
+ model = ModelFactory.create(
58
+ model_platform=self.model_platform,
59
+ model_type=self.model_type,
60
+ model_config_dict=self.model_config_dict
61
+ )
62
+
63
+ agent = ChatAgent(
64
+ model=model,
65
+ system_message=bm.make_assistant_message('system', prompt),
66
+ tools=[FunctionTool(tool.call_tool_for_toolformer, openai_tool_schema=tool.as_openai_info()) for tool in tools]
67
+ )
68
+
69
+ return CamelConversation(self, agent, category)
70
+
71
+ def make_openai_toolformer(model_type_internal):
72
+ if model_type_internal == 'gpt-4o':
73
+ model_type = ModelType.GPT_4O
74
+ elif model_type_internal == 'gpt-4o-mini':
75
+ model_type = ModelType.GPT_4O_MINI
76
+ else:
77
+ raise ValueError('Model type must be either "gpt-4o" or "gpt-4o-mini".')
78
+
79
+ #formatted_tools = [FunctionTool(tool.call_tool_for_toolformer, tool.as_openai_info()) for tool in tools]
80
+
81
+ return CamelToolformer(
82
+ model_platform=ModelPlatformType.OPENAI,
83
+ model_type=model_type,
84
+ model_config_dict=ChatGPTConfig(temperature=0.2).as_dict(),
85
+ name=model_type_internal
86
+ )
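A usage sketch for the CAMEL-backed toolformer above, assuming OPENAI_API_KEY is set and reusing the weather_tool sketched after toolformers/base.py:

    from toolformers.camel import make_openai_toolformer

    toolformer = make_openai_toolformer('gpt-4o-mini')
    conversation = toolformer.new_conversation(
        'You are a weather assistant. Use the available tools.',
        tools=[weather_tool],
        category='weather',
    )
    reply = conversation.chat('What is the forecast for London on 2023-11-01?')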
toolformers/gemini.py ADDED
@@ -0,0 +1,100 @@
1
+ import datetime
2
+ import os
3
+ from random import random
4
+ import time
5
+ import traceback
6
+ from typing import List
7
+
8
+ from toolformers.base import Conversation, Tool, Toolformer
9
+
10
+ import google.generativeai as genai
11
+ from google.generativeai.generative_models import ChatSession
12
+
13
+ genai.configure(api_key=os.environ['GOOGLE_API_KEY'])
14
+
15
+ class GeminiConversation(Conversation):
16
+ def __init__(self, model_name, chat_agent : ChatSession, category=None):
17
+ self.model_name = model_name
18
+ self.chat_agent = chat_agent
19
+ self.category = category
20
+
21
+ def chat(self, message, role='user', print_output=True):
22
+ agent_id = os.environ.get('AGENT_ID', None)
23
+ time_start = datetime.datetime.now()
24
+
25
+ exponential_backoff_lower = 30
26
+ exponential_backoff_higher = 60
27
+ for i in range(5):
28
+ try:
29
+ response = self.chat_agent.send_message({
30
+ 'role': role,
31
+ 'parts': [
32
+ message
33
+ ]
34
+ })
35
+ break
36
+ except Exception as e:
37
+ print(e)
38
+ if '429' in str(e):
39
+ print('Rate limit exceeded. Waiting with random exponential backoff.')
40
+ if i < 4:
41
+ time.sleep(random() * (exponential_backoff_higher - exponential_backoff_lower) + exponential_backoff_lower)
42
+ exponential_backoff_lower *= 2
43
+ exponential_backoff_higher *= 2
44
+ elif 'candidates[0]' in traceback.format_exc():
45
+ # When Gemini has nothing to say, it raises an error with this message
46
+ print('No response')
47
+ return 'No response'
48
+ elif '500' in str(e):
49
+ # Sometimes Gemini just decides to return a 500 error for absolutely no reason. Retry.
50
+ print('500 error')
51
+ time.sleep(5)
52
+ traceback.print_exc()
53
+ else:
54
+ raise e
55
+
56
+ time_end = datetime.datetime.now()
57
+
58
+ usage_info = {
59
+ 'prompt_tokens': response.usage_metadata.prompt_token_count,
60
+ 'completion_tokens': response.usage_metadata.candidates_token_count
61
+ }
62
+
63
+ #send_usage_to_db(
64
+ # usage_info,
65
+ # time_start,
66
+ # time_end,
67
+ # agent_id,
68
+ # self.category,
69
+ # self.model_name
70
+ #)
71
+
72
+ reply = response.text
73
+
74
+ if print_output:
75
+ print(reply)
76
+
77
+ return reply
78
+
79
+ class GeminiToolformer(Toolformer):
80
+ def __init__(self, model_name):
81
+ self.model_name = model_name
82
+
83
+ def new_conversation(self, system_prompt, tools : List[Tool], category=None) -> Conversation:
84
+ print('Tools:')
85
+ print('\n'.join([str(tool.as_openai_info()) for tool in tools]))
86
+ model = genai.GenerativeModel(
87
+ model_name=self.model_name,
88
+ system_instruction=system_prompt,
89
+ tools=[tool.as_gemini_tool() for tool in tools]
90
+ )
91
+
92
+ chat = model.start_chat(enable_automatic_function_calling=True)
93
+
94
+ return GeminiConversation(self.model_name, chat, category)
95
+
96
+ def make_gemini_toolformer(model_name):
97
+ if model_name not in ['gemini-1.5-flash', 'gemini-1.5-pro']:
98
+ raise ValueError(f"Unknown model name: {model_name}")
99
+
100
+ return GeminiToolformer(model_name)
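A corresponding sketch for the Gemini backend; note that importing toolformers.gemini requires GOOGLE_API_KEY, since genai.configure() runs at import time. The weather_tool is the one sketched earlier:

    from toolformers.gemini import make_gemini_toolformer

    toolformer = make_gemini_toolformer('gemini-1.5-flash')
    conversation = toolformer.new_conversation(
        'You are a weather assistant.', tools=[weather_tool])
    print(conversation.chat('What is the forecast for New York tomorrow?'))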
toolformers/huggingface_agent.py ADDED
@@ -0,0 +1,39 @@
1
+ from huggingface_hub import InferenceClient
2
+ from transformers import ReactCodeAgent
3
+ from transformers import tool as function_to_tool
4
+
5
+ from toolformers.base import Conversation, Toolformer
6
+
7
+ class HuggingFaceConversation(Conversation):
8
+ def __init__(self, agent : ReactCodeAgent, prompt, category=None):
9
+ self.agent = agent
10
+ self.messages = [('system', prompt)]
11
+ self.category = category
12
+
13
+ def chat(self, message, role='user', print_output=True) -> str:
14
+ self.messages.append((role, message))
15
+
16
+ final_prompt = 'For context, here are the previous messages in the conversation:\n\n'
17
+
18
+ for role, message in self.messages:
19
+ final_prompt += f'{role.capitalize()}: {message}\n'
20
+
21
+ final_prompt += "Don't worry, you don't need to use the same format to reply. Stick with the Task:/Action:/etc. format.\n\n"
22
+
23
+ response = self.agent.run(final_prompt)
24
+ print(response)
25
+ return response
26
+
27
+ class HuggingFaceToolformer(Toolformer):
28
+ def __init__(self, model_name, max_tokens=2000):
29
+ self.model = InferenceClient(model=model_name)
30
+ self.max_tokens = max_tokens
31
+
32
+ def llm_engine(self, messages, stop_sequences=["Task"]) -> str:
33
+ response = self.model.chat_completion(messages, stop=stop_sequences, max_tokens=self.max_tokens)
34
+ answer = response.choices[0].message.content
35
+ return answer
36
+
37
+ def new_conversation(self, prompt, tools, category=None):
38
+ agent = ReactCodeAgent(tools=[function_to_tool(tool.as_annotated_function()) for tool in tools], llm_engine=self.llm_engine)
39
+ return HuggingFaceConversation(agent, prompt, category=category)
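This backend converts each Tool with as_annotated_function() and drives a ReactCodeAgent over an InferenceClient. A sketch, assuming access to a hosted instruct model (the model name below is an assumption) and the weather_tool from earlier:

    from toolformers.huggingface_agent import HuggingFaceToolformer

    # Model name is an assumption; any chat-capable hosted model should fit here.
    toolformer = HuggingFaceToolformer('meta-llama/Meta-Llama-3-8B-Instruct', max_tokens=1000)
    conversation = toolformer.new_conversation('You are a weather assistant.', [weather_tool])
    print(conversation.chat('What is the weather in Paris?'))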
toolformers/langchain_agent.py ADDED
@@ -0,0 +1,73 @@
1
+ import json
2
+
3
+ # Import relevant functionality
4
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
5
+ from langgraph.checkpoint.memory import MemorySaver
6
+ from langgraph.prebuilt import create_react_agent
7
+ from langchain_anthropic import ChatAnthropic
8
+
9
+ import sys
10
+ sys.path.append('.')
11
+ from toolformers.base import Tool as AgoraTool
12
+
13
+
14
+ from langchain_core.tools import tool as function_to_tool
15
+
16
+ from toolformers.base import StringParameter, Toolformer, Conversation
17
+
18
+
19
+
20
+ class LangChainConversation(Conversation):
21
+ def __init__(self, agent, messages, category=None):
22
+ self.agent = agent
23
+ self.messages = messages
24
+ self.category = category
25
+
26
+ def chat(self, message, role='user', print_output=True) -> str:
27
+ self.messages.append(HumanMessage(content=message))
28
+ final_message = ''
29
+ for chunk in self.agent.stream({"messages": self.messages}, stream_mode="values"):
30
+ print(chunk)
31
+ print("----")
32
+ for message in chunk['messages']:
33
+ if isinstance(message, AIMessage):
34
+ content = message.content
35
+ if isinstance(content, str):
36
+ final_message += content
37
+ else:
38
+ for content_chunk in content:
39
+ if isinstance(content_chunk, str):
40
+ final_message += content_chunk
41
+ #final_message += chunk['agent']['messages'].content
42
+
43
+ self.messages.append(AIMessage(content=final_message))
44
+ #print(final_message)
45
+
46
+ return final_message
47
+
48
+ class LangChainAnthropicToolformer(Toolformer):
49
+ def __init__(self, model_name, api_key):
50
+ self.model_name = model_name
51
+ self.api_key = api_key
52
+
53
+ def new_conversation(self, prompt, tools, category=None):
54
+ tools = [function_to_tool(tool.as_annotated_function()) for tool in tools]
55
+ model = ChatAnthropic(model_name=self.model_name, api_key=self.api_key)
56
+ agent_executor = create_react_agent(model, tools)
57
+
58
+ return LangChainConversation(agent_executor, [SystemMessage(prompt)], category)
59
+
60
+
61
+ #weather_tool = AgoraTool("WeatherForecastAPI", "A simple tool that returns the weather", [StringParameter(
62
+ # name="location",
63
+ # description="The name of the location for which the weather forecast is requested.",
64
+ # required=True
65
+ #)], lambda location: 'Sunny', {
66
+ # "type": "string"
67
+ #})
68
+ #
69
+ #tools = [agora_tool_to_langchain(weather_tool)]
70
+ #toolformer = LangChainAnthropicToolformer("claude-3-sonnet-20240229", '<ANTHROPIC_API_KEY>')
71
+ #conversation = toolformer.new_conversation('You are a weather bot', [weather_tool])
72
+ #
73
+ #print(conversation.chat('What is the weather in San Francisco?'))
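The commented-out block above sketches the same flow; a cleaned-up version that reads the key from the environment instead of hard-coding it (assumes ANTHROPIC_API_KEY is set and reuses the earlier weather_tool):

    import os
    from toolformers.langchain_agent import LangChainAnthropicToolformer

    toolformer = LangChainAnthropicToolformer('claude-3-sonnet-20240229',
                                              os.environ['ANTHROPIC_API_KEY'])
    conversation = toolformer.new_conversation('You are a weather bot', [weather_tool])
    print(conversation.chat('What is the weather in San Francisco?'))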
toolformers/sambanova/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .core import SambanovaToolformer
toolformers/sambanova/api_gateway.py ADDED
@@ -0,0 +1,146 @@
1
+ import logging
2
+ import os
3
+ import sys
4
+ from typing import Optional
5
+
6
+ from langchain_community.llms.sambanova import SambaStudio
7
+ from langchain_core.language_models.llms import LLM
8
+
9
+ current_dir = os.path.dirname(os.path.abspath(__file__))
10
+ utils_dir = os.path.abspath(os.path.join(current_dir, '..'))
11
+ repo_dir = os.path.abspath(os.path.join(utils_dir, '..'))
12
+ sys.path.append(utils_dir)
13
+ sys.path.append(repo_dir)
14
+
15
+ from toolformers.sambanova.sambanova_langchain import SambaNovaCloud
16
+
17
+ EMBEDDING_MODEL = 'intfloat/e5-large-v2'
18
+ NORMALIZE_EMBEDDINGS = True
19
+
20
+ # Configure the logger
21
+ logging.basicConfig(
22
+ level=logging.INFO,
23
+ format='%(asctime)s [%(levelname)s] - %(message)s',
24
+ handlers=[
25
+ logging.StreamHandler(),
26
+ ],
27
+ )
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class APIGateway:
32
+ @staticmethod
33
+ def load_llm(
34
+ type: str,
35
+ streaming: bool = False,
36
+ coe: bool = False,
37
+ do_sample: Optional[bool] = None,
38
+ max_tokens_to_generate: Optional[int] = None,
39
+ temperature: Optional[float] = None,
40
+ select_expert: Optional[str] = None,
41
+ top_p: Optional[float] = None,
42
+ top_k: Optional[int] = None,
43
+ repetition_penalty: Optional[float] = None,
44
+ stop_sequences: Optional[str] = None,
45
+ process_prompt: Optional[bool] = False,
46
+ sambastudio_base_url: Optional[str] = None,
47
+ sambastudio_base_uri: Optional[str] = None,
48
+ sambastudio_project_id: Optional[str] = None,
49
+ sambastudio_endpoint_id: Optional[str] = None,
50
+ sambastudio_api_key: Optional[str] = None,
51
+ sambanova_url: Optional[str] = None,
52
+ sambanova_api_key: Optional[str] = None,
53
+ ) -> LLM:
54
+ """Loads a langchain Sambanova llm model given a type and parameters
55
+ Args:
56
+ type (str): whether to use SambaStudio ("sambastudio") or SambaNova Cloud ("sncloud")
57
+ streaming (bool): whether to use the streaming method. Defaults to False.
58
+ coe (bool): whether to use coe model. Defaults to False.
59
+
60
+ do_sample (bool) : Optional whether to sample.
61
+ max_tokens_to_generate (int) : Optional max number of tokens to generate.
62
+ temperature (float) : Optional model temperature.
63
+ select_expert (str) : Optional expert to use when using CoE models.
64
+ top_p (float) : Optional model top_p.
65
+ top_k (int) : Optional model top_k.
66
+ repetition_penalty (float) : Optional model repetition penalty.
67
+ stop_sequences (str) : Optional model stop sequences.
68
+ process_prompt (bool) : Optional. Defaults to False.
69
+
70
+ sambastudio_base_url (str): Optional SambaStudio environment URL.
71
+ sambastudio_base_uri (str): Optional SambaStudio base URI.
72
+ sambastudio_project_id (str): Optional SambaStudio project ID.
73
+ sambastudio_endpoint_id (str): Optional SambaStudio endpoint ID.
74
+ sambastudio_api_key (str): Optional SambaStudio endpoint API key.
75
+
76
+ sambanova_url (str): Optional SambaNova Cloud URL.
77
+ sambanova_api_key (str): Optional SambaNovaCloud API key.
78
+
79
+ Returns:
80
+ langchain llm model
81
+ """
82
+
83
+ if type == 'sambastudio':
84
+ envs = {
85
+ 'sambastudio_base_url': sambastudio_base_url,
86
+ 'sambastudio_base_uri': sambastudio_base_uri,
87
+ 'sambastudio_project_id': sambastudio_project_id,
88
+ 'sambastudio_endpoint_id': sambastudio_endpoint_id,
89
+ 'sambastudio_api_key': sambastudio_api_key,
90
+ }
91
+ envs = {k: v for k, v in envs.items() if v is not None}
92
+ if coe:
93
+ model_kwargs = {
94
+ 'do_sample': do_sample,
95
+ 'max_tokens_to_generate': max_tokens_to_generate,
96
+ 'temperature': temperature,
97
+ 'select_expert': select_expert,
98
+ 'top_p': top_p,
99
+ 'top_k': top_k,
100
+ 'repetition_penalty': repetition_penalty,
101
+ 'stop_sequences': stop_sequences,
102
+ 'process_prompt': process_prompt,
103
+ }
104
+ model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}
105
+
106
+ llm = SambaStudio(
107
+ **envs,
108
+ streaming=streaming,
109
+ model_kwargs=model_kwargs,
110
+ )
111
+ else:
112
+ model_kwargs = {
113
+ 'do_sample': do_sample,
114
+ 'max_tokens_to_generate': max_tokens_to_generate,
115
+ 'temperature': temperature,
116
+ 'top_p': top_p,
117
+ 'top_k': top_k,
118
+ 'repetition_penalty': repetition_penalty,
119
+ 'stop_sequences': stop_sequences,
120
+ }
121
+ model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}
122
+ llm = SambaStudio(
123
+ **envs,
124
+ streaming=streaming,
125
+ model_kwargs=model_kwargs,
126
+ )
127
+
128
+ elif type == 'sncloud':
129
+ envs = {
130
+ 'sambanova_url': sambanova_url,
131
+ 'sambanova_api_key': sambanova_api_key,
132
+ }
133
+ envs = {k: v for k, v in envs.items() if v is not None}
134
+ llm = SambaNovaCloud(
135
+ **envs,
136
+ max_tokens=max_tokens_to_generate,
137
+ model=select_expert,
138
+ temperature=temperature,
139
+ top_k=top_k,
140
+ top_p=top_p,
141
+ )
142
+
143
+ else:
144
+ raise ValueError(f"Invalid LLM API: {type}, only 'sncloud' and 'sambastudio' are supported.")
145
+
146
+ return llm
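APIGateway.load_llm dispatches between a SambaStudio endpoint and SambaNova Cloud, dropping any unset keyword before constructing the langchain LLM. A sketch of the 'sncloud' path, assuming SAMBANOVA_API_KEY (and optionally SAMBANOVA_URL) are set; the model name is an assumption:

    from toolformers.sambanova.api_gateway import APIGateway

    llm = APIGateway.load_llm(
        type='sncloud',
        streaming=True,
        max_tokens_to_generate=1024,
        temperature=0.2,
        select_expert='llama3-8b',  # mapped to the SambaNovaCloud `model` field
    )
    print(llm.invoke('Say hello in one word.'))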
toolformers/sambanova/core.py ADDED
@@ -0,0 +1,46 @@
1
+ import datetime
2
+ import os
3
+ from typing import List
4
+
5
+ from toolformers.base import Conversation, Toolformer, Tool
6
+ from toolformers.sambanova.function_calling import FunctionCallingLlm
7
+
8
+ class SambanovaConversation(Conversation):
9
+ def __init__(self, model_name, function_calling_llm : FunctionCallingLlm, category=None):
10
+ self.model_name = model_name
11
+ self.function_calling_llm = function_calling_llm
12
+ self.category = category
13
+
14
+ def chat(self, message, role='user', print_output=True):
15
+ if role != 'user':
16
+ raise ValueError('Role must be "user"')
17
+
18
+ agent_id = os.environ.get('AGENT_ID', None)
19
+
20
+ start_time = datetime.datetime.now()
21
+
22
+ response, usage_data = self.function_calling_llm.function_call_llm(message)
23
+
24
+ end_time = datetime.datetime.now()
25
+
26
+ print('Usage data:', usage_data)
27
+ if print_output:
28
+ print(response)
29
+
30
+ #send_usage_to_db(usage_data, start_time, end_time, agent_id, self.category, self.model_name)
31
+
32
+ return response
33
+
34
+ class SambanovaToolformer(Toolformer):
35
+ def __init__(self, model_name: str):
36
+ self.model_name = model_name
37
+
38
+ def new_conversation(self, system_prompt: str, tools: List[Tool], category=None) -> SambanovaConversation:
39
+ function_calling_llm = FunctionCallingLlm(system_prompt=system_prompt, tools=tools, select_expert=self.model_name)
40
+ return SambanovaConversation(self.model_name, function_calling_llm, category)
41
+
42
+ #def make_llama_toolformer(model_name, system_prompt: str, tools: List[Tool]):
43
+ # if model_name not in ['llama3-8b', 'llama3-70b', 'llama3-405b']:
44
+ # raise ValueError(f"Unknown model name: {model_name}")
45
+ #
46
+ # return SambanovaToolformer(model_name, system_prompt, tools)
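A usage sketch for the SambaNova toolformer above, assuming SAMBANOVA_API_KEY is set; the model name is an assumption and weather_tool is the Tool sketched earlier:

    from toolformers.sambanova import SambanovaToolformer

    toolformer = SambanovaToolformer('llama3-70b')  # model name is an assumption
    conversation = toolformer.new_conversation(
        'You are a weather assistant.', tools=[weather_tool])
    print(conversation.chat('What is the forecast for Rome?'))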
toolformers/sambanova/function_calling.py ADDED
@@ -0,0 +1,312 @@
1
+ import json
2
+ import os
3
+ import re
4
+ from random import random
5
+ from pprint import pprint
6
+ import time
7
+ from typing import List, Optional, Union
8
+
9
+ from langchain_core.messages.ai import AIMessage
10
+ from langchain_core.messages.human import HumanMessage
11
+ from langchain_core.messages.tool import ToolMessage
12
+ from langchain_core.prompts import ChatPromptTemplate
13
+ from langchain_core.runnables import RunnableLambda
14
+
15
+ from toolformers.base import Tool, StringParameter
16
+ from toolformers.sambanova.api_gateway import APIGateway
17
+
18
+ from toolformers.sambanova.utils import get_total_usage, usage_tracker
19
+
20
+
21
+ FUNCTION_CALLING_SYSTEM_PROMPT = """You have access to the following tools:
22
+
23
+ {tools}
24
+
25
+ You can call one or more tools by adding a <ToolCalls> section to your message. For example:
26
+ <ToolCalls>
27
+ ```json
28
+ [{{
29
+ "tool": <name of the selected tool>,
30
+ "tool_input": <parameters for the selected tool, matching the tool's JSON schema>
31
+ }}]
32
+ ```
33
+ </ToolCalls>
34
+
35
+ Note that you can select multiple tools at once by adding more objects to the list. Do not add \
36
+ multiple <ToolCalls> sections to the same message.
37
+ You will see the invocation of the tools in the response.
38
+
39
+
40
+ Think step by step
41
+ Do not call a tool if the input depends on another tool output that you do not have yet.
42
+ Do not try to answer until you get all the tools output, if you do not have an answer yet, you can continue calling tools until you do.
43
+ Your answer should be in the same language as the initial query.
44
+
45
+ """ # noqa E501
46
+
47
+
48
+ conversational_response = Tool(
49
+ name='ConversationalResponse',
50
+ description='Respond conversationally only if no other tools should be called for a given query, or if you have a final answer. Response must be in the same language as the user query.',
51
+ parameters=[StringParameter(name='response', description='Conversational response to the user. Must be in the same language as the user query.', required=True)],
52
+ function=None
53
+ )
54
+
55
+
56
+ class FunctionCallingLlm:
57
+ """
58
+ function calling llm class
59
+ """
60
+
61
+ def __init__(
62
+ self,
63
+ tools: Optional[Union[Tool, List[Tool]]] = None,
64
+ default_tool: Optional[Tool] = None,
65
+ system_prompt: Optional[str] = None,
66
+ prod_mode: bool = False,
67
+ api: str = 'sncloud',
68
+ coe: bool = False,
69
+ do_sample: bool = False,
70
+ max_tokens_to_generate: Optional[int] = None,
71
+ temperature: float = 0.2,
72
+ select_expert: Optional[str] = None,
73
+ ) -> None:
74
+ """
75
+ Args:
76
+ tools (Optional[Union[Tool, List[Tool]]]): The tools to use.
77
+ default_tool (Optional[Tool]): The default tool to use.
78
+ defaults to ConversationalResponse
79
+ system_prompt (Optional[str]): The system prompt to use. defaults to FUNCTION_CALLING_SYSTEM_PROMPT
80
+ prod_mode (bool): Whether to use production mode. Defaults to False.
81
+ api (str): The api to use. Defaults to 'sncloud'.
82
+ coe (bool): Whether to use coe. Defaults to False.
83
+ do_sample (bool): Whether to do sample. Defaults to False.
84
+ max_tokens_to_generate (Optional[int]): The max tokens to generate. If None, the model will attempt to use the maximum available tokens.
85
+ temperature (float): The model temperature. Defaults to 0.2.
86
+ select_expert (Optional[str]): The expert to use. Defaults to None.
87
+ """
88
+ self.prod_mode = prod_mode
89
+ sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
90
+ self.api = api
91
+ self.llm = APIGateway.load_llm(
92
+ type=api,
93
+ streaming=True,
94
+ coe=coe,
95
+ do_sample=do_sample,
96
+ max_tokens_to_generate=max_tokens_to_generate,
97
+ temperature=temperature,
98
+ select_expert=select_expert,
99
+ process_prompt=False,
100
+ sambanova_api_key=sambanova_api_key,
101
+ )
102
+
103
+ if isinstance(tools, Tool):
104
+ tools = [tools]
105
+ self.tools = tools
106
+ if system_prompt is None:
107
+ system_prompt = ''
108
+
109
+ system_prompt = system_prompt.replace('{','{{').replace('}', '}}')
110
+
111
+ if len(self.tools) > 0:
112
+ system_prompt += '\n\n'
113
+ system_prompt += FUNCTION_CALLING_SYSTEM_PROMPT
114
+ self.system_prompt = system_prompt
115
+
116
+ if default_tool is None:
117
+ default_tool = conversational_response
118
+
119
+ def execute(self, invoked_tools: List[dict]) -> tuple[bool, List[str]]:
120
+ """
121
+ Given a list of tool invocations returned by the LLM,
122
+ execute each one by looking up its name in tools_map and passing its input arguments.
123
+ If there is only one tool call and it is the default conversational one, the response is marked as the final response.
124
+
125
+ Args:
126
+ invoked_tools (List[dict]): The list of tool executions generated by the LLM.
127
+ """
128
+ if self.tools is not None:
129
+ tools_map = {tool.name.lower(): tool for tool in self.tools}
130
+ else:
131
+ tools_map = {}
132
+ tool_msg = "Tool '{name}' response: {response}"
133
+ tools_msgs = []
134
+ if len(invoked_tools) == 1 and invoked_tools[0]['tool'].lower() == 'conversationalresponse':
135
+ final_answer = True
136
+ return final_answer, [invoked_tools[0]['tool_input']['response']]
137
+
138
+ final_answer = False
139
+
140
+ for tool in invoked_tools:
141
+ if tool['tool'].lower() == 'invocationerror':
142
+ tools_msgs.append(f'Tool invocation error: {tool["tool_input"]}')
143
+ elif tool['tool'].lower() != 'conversationalresponse':
144
+ print(f"\n\n---\nTool {tool['tool'].lower()} invoked with input {tool['tool_input']}\n")
145
+
146
+ if tool['tool'].lower() not in tools_map:
147
+ tools_msgs.append(f'Tool {tool["tool"]} not found')
148
+ else:
149
+ response = tools_map[tool['tool'].lower()].call_tool_for_toolformer(**tool['tool_input'])
150
+ # print(f'Tool response: {str(response)}\n---\n\n')
151
+ tools_msgs.append(tool_msg.format(name=tool['tool'], response=str(response)))
152
+ return final_answer, tools_msgs
153
+
154
+ def json_finder(self, input_string: str) -> Optional[str]:
155
+ """
156
+ Find JSON tool-call structures in an LLM string response; badly formatted JSON is reported back as an InvocationError.
157
+
158
+ Args:
159
+ input_string (str): The string to find the json structure in.
160
+ """
161
+
162
+ # 1. Ideal pattern: correctly surrounded by <ToolCalls> tags
163
+ json_pattern_1 = re.compile(r'<ToolCalls\>(.*)</ToolCalls\>', re.DOTALL + re.IGNORECASE)
164
+ # 2. Sometimes the closing tag is missing
165
+ json_pattern_2 = re.compile(r'<ToolCalls\>(.*)', re.DOTALL + re.IGNORECASE)
166
+ # 3. Sometimes it accidentally uses <ToolCall> instead of <ToolCalls>
167
+ json_pattern_3 = re.compile(r'<ToolCall\>(.*)</ToolCall\>', re.DOTALL + re.IGNORECASE)
168
+ # 4. Sometimes it accidentally uses <ToolCall> instead of <ToolCalls> and the closing tag is missing
169
+ json_pattern_4 = re.compile(r'<ToolCall\>(.*)', re.DOTALL + re.IGNORECASE)
170
+
171
+ # Find the first JSON structure in the string
172
+ json_match = json_pattern_1.search(input_string) or json_pattern_2.search(input_string) or json_pattern_3.search(input_string) or json_pattern_4.search(input_string)
173
+ if json_match:
174
+ json_str = json_match.group(1)
175
+
176
+ # 1. Outermost list of JSON object
177
+ call_pattern_1 = re.compile(r'\[.*\]', re.DOTALL)
178
+ # 2. Outermost JSON object
179
+ call_pattern_2 = re.compile(r'\{.*\}', re.DOTALL)
180
+
181
+ call_match_1 = call_pattern_1.search(json_str)
182
+ call_match_2 = call_pattern_2.search(json_str)
183
+
184
+ if call_match_1:
185
+ json_str = call_match_1.group(0)
186
+ try:
187
+ return json.loads(json_str)
188
+ except Exception as e:
189
+ return [{'tool': 'InvocationError', 'tool_input' : str(e)}]
190
+ elif call_match_2:
191
+ json_str = call_match_2.group(0)
192
+ try:
193
+ return [json.loads(json_str)]
194
+ except Exception as e:
195
+ return [{'tool': 'InvocationError', 'tool_input' : str(e)}]
196
+ else:
197
+ return [{'tool': 'InvocationError', 'tool_input' : 'Could not find JSON object in the <ToolCalls> section'}]
198
+ else:
199
+ dummy_json_response = [{'tool': 'ConversationalResponse', 'tool_input': {'response': input_string}}]
200
+ json_str = dummy_json_response
201
+ return json_str
202
+
203
+ def msgs_to_llama3_str(self, msgs: list) -> str:
204
+ """
205
+ Convert a list of langchain messages with roles to the expected Llama 3 input format.
206
+
207
+ Args:
208
+ msgs (list): The list of langchain messages.
209
+ """
210
+ formatted_msgs = []
211
+ for msg in msgs:
212
+ if msg.type == 'system':
213
+ sys_placeholder = (
214
+ '<|begin_of_text|><|start_header_id|>system<|end_header_id|> {msg}'
215
+ )
216
+ formatted_msgs.append(sys_placeholder.format(msg=msg.content))
217
+ elif msg.type == 'human':
218
+ human_placeholder = '<|eot_id|><|start_header_id|>user<|end_header_id|>\nUser: {msg} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\nAssistant:' # noqa E501
219
+ formatted_msgs.append(human_placeholder.format(msg=msg.content))
220
+ elif msg.type == 'ai':
221
+ assistant_placeholder = '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\nAssistant: {msg}'
222
+ formatted_msgs.append(assistant_placeholder.format(msg=msg.content))
223
+ elif msg.type == 'tool':
224
+ tool_placeholder = '<|eot_id|><|start_header_id|>tools<|end_header_id|>\n{msg} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\nAssistant:' # noqa E501
225
+ formatted_msgs.append(tool_placeholder.format(msg=msg.content))
226
+ else:
227
+ raise ValueError(f'Invalid message type: {msg.type}')
228
+ return '\n'.join(formatted_msgs)
229
+
230
+ def msgs_to_sncloud(self, msgs: list) -> list:
231
+ """
232
+ convert a list of langchain messages with roles to expected FastCoE input
233
+
234
+ Args:
235
+ msgs (list): The list of langchain messages.
236
+ """
237
+ formatted_msgs = []
238
+ for msg in msgs:
239
+ if msg.type == 'system':
240
+ formatted_msgs.append({'role': 'system', 'content': msg.content})
241
+ elif msg.type == 'human':
242
+ formatted_msgs.append({'role': 'user', 'content': msg.content})
243
+ elif msg.type == 'ai':
244
+ formatted_msgs.append({'role': 'assistant', 'content': msg.content})
245
+ elif msg.type == 'tool':
246
+ formatted_msgs.append({'role': 'tools', 'content': msg.content})
247
+ else:
248
+ raise ValueError(f'Invalid message type: {msg.type}')
249
+ return json.dumps(formatted_msgs)
250
+
251
+ def function_call_llm(self, query: str, max_it: int = 10, debug: bool = False) -> str:
252
+ """
253
+ invocation method for function calling workflow
254
+
255
+ Args:
256
+ query (str): The query to execute.
257
+ max_it (int, optional): The maximum number of iterations. Defaults to 10.
258
+ debug (bool, optional): Whether to print debug information. Defaults to False.
259
+ """
260
+ function_calling_chat_template = ChatPromptTemplate.from_messages([('system', self.system_prompt)])
261
+ tools_schemas = [tool.as_llama_schema() for tool in self.tools]
262
+
263
+ history = function_calling_chat_template.format_prompt(tools=tools_schemas).to_messages()
264
+
265
+ history.append(HumanMessage(query))
266
+ tool_call_id = 0 # identification for each tool calling required to create ToolMessages
267
+ with usage_tracker():
268
+
269
+ for i in range(max_it):
270
+ json_parsing_chain = RunnableLambda(self.json_finder)
271
+
272
+ if self.api == 'sncloud':
273
+ prompt = self.msgs_to_sncloud(history)
274
+ else:
275
+ prompt = self.msgs_to_llama3_str(history)
276
+ # print(f'\n\n---\nCalling function calling LLM with prompt: \n{prompt}\n')
277
+
278
+ exponential_backoff_lower = 30
279
+ exponential_backoff_higher = 60
280
+ llm_response = None
281
+ for _ in range(5):
282
+ try:
283
+ llm_response = self.llm.invoke(prompt, stream_options={'include_usage': True})
284
+ break
285
+ except Exception as e:
286
+ if '429' in str(e):
287
+ print('Rate limit exceeded. Waiting with random exponential backoff.')
288
+ time.sleep(random() * (exponential_backoff_higher - exponential_backoff_lower) + exponential_backoff_lower)
289
+ exponential_backoff_lower *= 2
290
+ exponential_backoff_higher *= 2
291
+ else:
292
+ raise e
293
+
294
+ print('LLM response:', llm_response)
295
+
296
+ # print(f'\nFunction calling LLM response: \n{llm_response}\n---\n')
297
+ parsed_tools_llm_response = json_parsing_chain.invoke(llm_response)
298
+
299
+ history.append(AIMessage(llm_response))
300
+ final_answer, tools_msgs = self.execute(parsed_tools_llm_response)
301
+ if final_answer: # if response was marked as final response in execution
302
+ final_response = tools_msgs[0]
303
+ if debug:
304
+ print('\n\n---\nFinal function calling LLM history: \n')
305
+ pprint(f'{history}')
306
+ return final_response, get_total_usage()
307
+ else:
308
+ history.append(ToolMessage('\n'.join(tools_msgs), tool_call_id=tool_call_id))
309
+ tool_call_id += 1
310
+
311
+
312
+ raise Exception('Not a final response yet', history)
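The system prompt above asks the model to wrap tool calls in a <ToolCalls> block containing a JSON list, and json_finder tolerates missing or variant closing tags. A sketch of that contract, assuming SAMBANOVA_API_KEY is set (the constructor builds the SambaNova Cloud client) and reusing the earlier weather_tool; the model name is an assumption:

    from toolformers.sambanova.function_calling import FunctionCallingLlm

    fc_llm = FunctionCallingLlm(system_prompt='You are a weather assistant.',
                                tools=[weather_tool],
                                select_expert='llama3-8b')

    # What a well-formed assistant reply looks like, and how it is parsed:
    reply = ('Let me check that.\n'
             '<ToolCalls>\n'
             '[{"tool": "WeatherForecastAPI", "tool_input": {"location": "London"}}]\n'
             '</ToolCalls>')
    print(fc_llm.json_finder(reply))
    # -> [{'tool': 'WeatherForecastAPI', 'tool_input': {'location': 'London'}}]

    # End-to-end, function_call_llm(query) alternates LLM calls and tool executions
    # and returns (final_response, usage), as used by SambanovaConversation.chat.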
toolformers/sambanova/sambanova_langchain.py ADDED
@@ -0,0 +1,827 @@
1
+ """Langchain Wrapper around Sambanova LLM APIs."""
2
+
3
+ import json
4
+ from typing import Any, Dict, Generator, Iterator, List, Optional, Union
5
+
6
+ import requests
7
+ from langchain_core.callbacks.manager import CallbackManagerForLLMRun
8
+ from langchain_core.language_models.llms import LLM
9
+ from langchain_core.outputs import GenerationChunk
10
+ from langchain_core.pydantic_v1 import Extra
11
+ from langchain_core.utils import get_from_dict_or_env, pre_init
12
+ from langchain_core.runnables import RunnableConfig, ensure_config
13
+ from langchain_core.language_models.base import (
14
+ LanguageModelInput,
15
+ )
16
+ from toolformers.sambanova.utils import append_to_usage_tracker
17
+
18
+
19
+ class SSEndpointHandler:
20
+ """
21
+ SambaNova Systems Interface for SambaStudio model endpoints.
22
+
23
+ :param str host_url: Base URL of the DaaS API service
24
+ """
25
+
26
+ def __init__(self, host_url: str, api_base_uri: str):
27
+ """
28
+ Initialize the SSEndpointHandler.
29
+
30
+ :param str host_url: Base URL of the DaaS API service
31
+ :param str api_base_uri: Base URI of the DaaS API service
32
+ """
33
+ self.host_url = host_url
34
+ self.api_base_uri = api_base_uri
35
+ self.http_session = requests.Session()
36
+
37
+ def _process_response(self, response: requests.Response) -> Dict:
38
+ """
39
+ Processes the API response and returns the resulting dict.
40
+
41
+ All resulting dicts, regardless of success or failure, will contain the
42
+ `status_code` key with the API response status code.
43
+
44
+ If the API returned an error, the resulting dict will contain the key
45
+ `detail` with the error message.
46
+
47
+ If the API call was successful, the resulting dict will contain the key
48
+ `data` with the response data.
49
+
50
+ :param requests.Response response: the response object to process
51
+ :return: the response dict
52
+ :type: dict
53
+ """
54
+ result: Dict[str, Any] = {}
55
+ try:
56
+ result = response.json()
57
+ except Exception as e:
58
+ result['detail'] = str(e)
59
+ if 'status_code' not in result:
60
+ result['status_code'] = response.status_code
61
+ return result
62
+
63
+ def _process_streaming_response(
64
+ self,
65
+ response: requests.Response,
66
+ ) -> Generator[Dict, None, None]:
67
+ """Process the streaming response"""
68
+ if 'api/predict/nlp' in self.api_base_uri:
69
+ try:
70
+ import sseclient
71
+ except ImportError:
72
+ raise ImportError(
73
+ 'could not import sseclient library. Please install it with `pip install sseclient-py`.'
74
+ )
75
+ client = sseclient.SSEClient(response)
76
+ close_conn = False
77
+ for event in client.events():
78
+ if event.event == 'error_event':
79
+ close_conn = True
80
+ chunk = {
81
+ 'event': event.event,
82
+ 'data': event.data,
83
+ 'status_code': response.status_code,
84
+ }
85
+ yield chunk
86
+ if close_conn:
87
+ client.close()
88
+ elif 'api/v2/predict/generic' in self.api_base_uri or 'api/predict/generic' in self.api_base_uri:
89
+ try:
90
+ for line in response.iter_lines():
91
+ chunk = json.loads(line)
92
+ if 'status_code' not in chunk:
93
+ chunk['status_code'] = response.status_code
94
+ yield chunk
95
+ except Exception as e:
96
+ raise RuntimeError(f'Error processing streaming response: {e}')
97
+ else:
98
+ raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')
99
+
100
+ def _get_full_url(self, path: str) -> str:
101
+ """
102
+ Return the full API URL for a given path.
103
+
104
+ :param str path: the sub-path
105
+ :returns: the full API URL for the sub-path
106
+ :type: str
107
+ """
108
+ return f'{self.host_url}/{self.api_base_uri}/{path}'
109
+
110
+ def nlp_predict(
111
+ self,
112
+ project: str,
113
+ endpoint: str,
114
+ key: str,
115
+ input: Union[List[str], str],
116
+ params: Optional[str] = '',
117
+ stream: bool = False,
118
+ ) -> Dict:
119
+ """
120
+ NLP predict using inline input string.
121
+
122
+ :param str project: Project ID in which the endpoint exists
123
+ :param str endpoint: Endpoint ID
124
+ :param str key: API Key
125
+ :param str input_str: Input string
126
+ :param str params: Input params string
127
+ :returns: Prediction results
128
+ :type: dict
129
+ """
130
+ if isinstance(input, str):
131
+ input = [input]
132
+ if 'api/predict/nlp' in self.api_base_uri:
133
+ if params:
134
+ data = {'inputs': input, 'params': json.loads(params)}
135
+ else:
136
+ data = {'inputs': input}
137
+ elif 'api/v2/predict/generic' in self.api_base_uri:
138
+ items = [{'id': f'item{i}', 'value': item} for i, item in enumerate(input)]
139
+ if params:
140
+ data = {'items': items, 'params': json.loads(params)}
141
+ else:
142
+ data = {'items': items}
143
+ elif 'api/predict/generic' in self.api_base_uri:
144
+ if params:
145
+ data = {'instances': input, 'params': json.loads(params)}
146
+ else:
147
+ data = {'instances': input}
148
+ else:
149
+ raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')
150
+ response = self.http_session.post(
151
+ self._get_full_url(f'{project}/{endpoint}'),
152
+ headers={'key': key},
153
+ json=data,
154
+ )
155
+ return self._process_response(response)
156
+
157
+ def nlp_predict_stream(
158
+ self,
159
+ project: str,
160
+ endpoint: str,
161
+ key: str,
162
+ input: Union[List[str], str],
163
+ params: Optional[str] = '',
164
+ ) -> Iterator[Dict]:
165
+ """
166
+ NLP predict using inline input string.
167
+
168
+ :param str project: Project ID in which the endpoint exists
169
+ :param str endpoint: Endpoint ID
170
+ :param str key: API Key
171
+ :param str input_str: Input string
172
+ :param str params: Input params string
173
+ :returns: Prediction results
174
+ :type: dict
175
+ """
176
+ if 'api/predict/nlp' in self.api_base_uri:
177
+ if isinstance(input, str):
178
+ input = [input]
179
+ if params:
180
+ data = {'inputs': input, 'params': json.loads(params)}
181
+ else:
182
+ data = {'inputs': input}
183
+ elif 'api/v2/predict/generic' in self.api_base_uri:
184
+ if isinstance(input, str):
185
+ input = [input]
186
+ items = [{'id': f'item{i}', 'value': item} for i, item in enumerate(input)]
187
+ if params:
188
+ data = {'items': items, 'params': json.loads(params)}
189
+ else:
190
+ data = {'items': items}
191
+ elif 'api/predict/generic' in self.api_base_uri:
192
+ if isinstance(input, list):
193
+ input = input[0]
194
+ if params:
195
+ data = {'instance': input, 'params': json.loads(params)}
196
+ else:
197
+ data = {'instance': input}
198
+ else:
199
+ raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')
200
+ # Streaming output
201
+ response = self.http_session.post(
202
+ self._get_full_url(f'stream/{project}/{endpoint}'),
203
+ headers={'key': key},
204
+ json=data,
205
+ stream=True,
206
+ )
207
+ for chunk in self._process_streaming_response(response):
208
+ yield chunk
209
+
210
+
211
+ class SambaStudio(LLM):
212
+ """
213
+ SambaStudio large language models.
214
+
215
+ To use, you should have the environment variables
216
+ ``SAMBASTUDIO_BASE_URL`` set with your SambaStudio environment URL.
217
+ ``SAMBASTUDIO_BASE_URI`` set with your SambaStudio api base URI.
218
+ ``SAMBASTUDIO_PROJECT_ID`` set with your SambaStudio project ID.
219
+ ``SAMBASTUDIO_ENDPOINT_ID`` set with your SambaStudio endpoint ID.
220
+ ``SAMBASTUDIO_API_KEY`` set with your SambaStudio endpoint API key.
221
+
222
+ https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite
223
+
224
+ read extra documentation in https://docs.sambanova.ai/sambastudio/latest/index.html
225
+
226
+ Example:
227
+ .. code-block:: python
228
+
229
+ from langchain_community.llms.sambanova import SambaStudio
230
+ SambaStudio(
231
+ sambastudio_base_url="your-SambaStudio-environment-URL",
232
+ sambastudio_base_uri="your-SambaStudio-base-URI",
233
+ sambastudio_project_id="your-SambaStudio-project-ID",
234
+ sambastudio_endpoint_id="your-SambaStudio-endpoint-ID",
235
+ sambastudio_api_key="your-SambaStudio-endpoint-API-key",
236
+ streaming=False,
237
+ model_kwargs={
238
+ "do_sample": False,
239
+ "max_tokens_to_generate": 1000,
240
+ "temperature": 0.7,
241
+ "top_p": 1.0,
242
+ "repetition_penalty": 1,
243
+ "top_k": 50,
244
+ #"process_prompt": False,
245
+ #"select_expert": "Meta-Llama-3-8B-Instruct"
246
+ },
247
+ )
248
+ """
249
+
250
+ sambastudio_base_url: str = ''
251
+ """Base url to use"""
252
+
253
+ sambastudio_base_uri: str = ''
254
+ """endpoint base uri"""
255
+
256
+ sambastudio_project_id: str = ''
257
+ """Project id on sambastudio for model"""
258
+
259
+ sambastudio_endpoint_id: str = ''
260
+ """endpoint id on sambastudio for model"""
261
+
262
+ sambastudio_api_key: str = ''
263
+ """sambastudio api key"""
264
+
265
+ model_kwargs: Optional[dict] = None
266
+ """Key word arguments to pass to the model."""
267
+
268
+ streaming: Optional[bool] = False
269
+ """Streaming flag to get streamed response."""
270
+
271
+ class Config:
272
+ """Configuration for this pydantic object."""
273
+
274
+ extra = 'forbid'#Extra.forbid
275
+
276
+ @classmethod
277
+ def is_lc_serializable(cls) -> bool:
278
+ return True
279
+
280
+ @property
281
+ def _identifying_params(self) -> Dict[str, Any]:
282
+ """Get the identifying parameters."""
283
+ return {**{'model_kwargs': self.model_kwargs}}
284
+
285
+ @property
286
+ def _llm_type(self) -> str:
287
+ """Return type of llm."""
288
+ return 'Sambastudio LLM'
289
+
290
+ @pre_init
291
+ def validate_environment(cls, values: Dict) -> Dict:
292
+ """Validate that api key and python package exists in environment."""
293
+ values['sambastudio_base_url'] = get_from_dict_or_env(values, 'sambastudio_base_url', 'SAMBASTUDIO_BASE_URL')
294
+ values['sambastudio_base_uri'] = get_from_dict_or_env(
295
+ values,
296
+ 'sambastudio_base_uri',
297
+ 'SAMBASTUDIO_BASE_URI',
298
+ default='api/predict/generic',
299
+ )
300
+ values['sambastudio_project_id'] = get_from_dict_or_env(
301
+ values, 'sambastudio_project_id', 'SAMBASTUDIO_PROJECT_ID'
302
+ )
303
+ values['sambastudio_endpoint_id'] = get_from_dict_or_env(
304
+ values, 'sambastudio_endpoint_id', 'SAMBASTUDIO_ENDPOINT_ID'
305
+ )
306
+ values['sambastudio_api_key'] = get_from_dict_or_env(values, 'sambastudio_api_key', 'SAMBASTUDIO_API_KEY')
307
+ return values
308
+
309
+ def _get_tuning_params(self, stop: Optional[List[str]]) -> str:
310
+ """
311
+ Get the tuning parameters to use when calling the LLM.
312
+
313
+ Args:
314
+ stop: Stop words to use when generating. Model output is cut off at the
315
+ first occurrence of any of the stop substrings.
316
+
317
+ Returns:
318
+ The tuning parameters as a JSON string.
319
+ """
320
+ _model_kwargs = self.model_kwargs or {}
321
+ _kwarg_stop_sequences = _model_kwargs.get('stop_sequences', [])
322
+ _stop_sequences = stop or _kwarg_stop_sequences
323
+ # if not _kwarg_stop_sequences:
324
+ # _model_kwargs["stop_sequences"] = ",".join(
325
+ # f'"{x}"' for x in _stop_sequences
326
+ # )
327
+ if 'api/v2/predict/generic' in self.sambastudio_base_uri:
328
+ tuning_params_dict = _model_kwargs
329
+ else:
330
+ tuning_params_dict = {k: {'type': type(v).__name__, 'value': str(v)} for k, v in (_model_kwargs.items())}
331
+ # _model_kwargs["stop_sequences"] = _kwarg_stop_sequences
332
+ tuning_params = json.dumps(tuning_params_dict)
333
+ return tuning_params
334
+
335
+ def _handle_nlp_predict(self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str) -> str:
336
+ """
337
+ Perform an NLP prediction using the SambaStudio endpoint handler.
338
+
339
+ Args:
340
+ sdk: The SSEndpointHandler to use for the prediction.
341
+ prompt: The prompt to use for the prediction.
342
+ tuning_params: The tuning parameters to use for the prediction.
343
+
344
+ Returns:
345
+ The prediction result.
346
+
347
+ Raises:
348
+ ValueError: If the prediction fails.
349
+ """
350
+ response = sdk.nlp_predict(
351
+ self.sambastudio_project_id,
352
+ self.sambastudio_endpoint_id,
353
+ self.sambastudio_api_key,
354
+ prompt,
355
+ tuning_params,
356
+ )
357
+ if response['status_code'] != 200:
358
+ optional_detail = response.get('detail')
359
+ if optional_detail:
360
+ raise RuntimeError(
361
+ f"Sambanova /complete call failed with status code "
362
+ f"{response['status_code']}.\n Details: {optional_detail}"
363
+ )
364
+ else:
365
+ raise RuntimeError(
366
+ f"Sambanova /complete call failed with status code "
367
+ f"{response['status_code']}.\n response {response}"
368
+ )
369
+ if 'api/predict/nlp' in self.sambastudio_base_uri:
370
+ return response['data'][0]['completion']
371
+ elif 'api/v2/predict/generic' in self.sambastudio_base_uri:
372
+ return response['items'][0]['value']['completion']
373
+ elif 'api/predict/generic' in self.sambastudio_base_uri:
374
+ return response['predictions'][0]['completion']
375
+ else:
376
+ raise ValueError(f'handling of endpoint uri: {self.sambastudio_base_uri} not implemented')
377
+
378
+ def _handle_completion_requests(self, prompt: Union[List[str], str], stop: Optional[List[str]]) -> str:
379
+ """
380
+ Perform a prediction using the SambaStudio endpoint handler.
381
+
382
+ Args:
383
+ prompt: The prompt to use for the prediction.
384
+ stop: stop sequences.
385
+
386
+ Returns:
387
+ The prediction result.
388
+
389
+ Raises:
390
+ ValueError: If the prediction fails.
391
+ """
392
+ ss_endpoint = SSEndpointHandler(self.sambastudio_base_url, self.sambastudio_base_uri)
393
+ tuning_params = self._get_tuning_params(stop)
394
+ return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)
395
+
396
+ def _handle_nlp_predict_stream(
397
+ self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str
398
+ ) -> Iterator[GenerationChunk]:
399
+ """
400
+ Perform a streaming request to the LLM.
401
+
402
+ Args:
403
+ sdk: The SVEndpointHandler to use for the prediction.
404
+ prompt: The prompt to use for the prediction.
405
+ tuning_params: The tuning parameters to use for the prediction.
406
+
407
+ Returns:
408
+ An iterator of GenerationChunks.
409
+ """
410
+ for chunk in sdk.nlp_predict_stream(
411
+ self.sambastudio_project_id,
412
+ self.sambastudio_endpoint_id,
413
+ self.sambastudio_api_key,
414
+ prompt,
415
+ tuning_params,
416
+ ):
417
+ if chunk['status_code'] != 200:
418
+ error = chunk.get('error')
419
+ if error:
420
+ optional_code = error.get('code')
421
+ optional_details = error.get('details')
422
+ optional_message = error.get('message')
423
+ raise ValueError(
424
+ f"Sambanova /complete call failed with status code "
425
+ f"{chunk['status_code']}.\n"
426
+ f"Message: {optional_message}\n"
427
+ f"Details: {optional_details}\n"
428
+ f"Code: {optional_code}\n"
429
+ )
430
+ else:
431
+ raise RuntimeError(
432
+ f"Sambanova /complete call failed with status code " f"{chunk['status_code']}." f"{chunk}."
433
+ )
434
+ if 'api/predict/nlp' in self.sambastudio_base_uri:
435
+ text = json.loads(chunk['data'])['stream_token']
436
+ elif 'api/v2/predict/generic' in self.sambastudio_base_uri:
437
+ text = chunk['result']['items'][0]['value']['stream_token']
438
+ elif 'api/predict/generic' in self.sambastudio_base_uri:
439
+ if len(chunk['result']['responses']) > 0:
440
+ text = chunk['result']['responses'][0]['stream_token']
441
+ else:
442
+ text = ''
443
+ else:
444
+ raise ValueError(f'handling of endpoint uri: {self.sambastudio_base_uri} not implemented')
445
+ generated_chunk = GenerationChunk(text=text)
446
+ yield generated_chunk
447
+
448
+ def _stream(
449
+ self,
450
+ prompt: Union[List[str], str],
451
+ stop: Optional[List[str]] = None,
452
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
453
+ **kwargs: Any,
454
+ ) -> Iterator[GenerationChunk]:
455
+ """Call out to Sambanova's complete endpoint.
456
+
457
+ Args:
458
+ prompt: The prompt to pass into the model.
459
+ stop: Optional list of stop words to use when generating.
460
+
461
+ Returns:
462
+ The string generated by the model.
463
+ """
464
+ ss_endpoint = SSEndpointHandler(self.sambastudio_base_url, self.sambastudio_base_uri)
465
+ tuning_params = self._get_tuning_params(stop)
466
+ try:
467
+ if self.streaming:
468
+ for chunk in self._handle_nlp_predict_stream(ss_endpoint, prompt, tuning_params):
469
+ if run_manager:
470
+ run_manager.on_llm_new_token(chunk.text)
471
+ yield chunk
472
+ else:
473
+ return
474
+ except Exception as e:
475
+ # Handle any errors raised by the inference endpoint
476
+ raise ValueError(f'Error raised by the inference endpoint: {e}') from e
477
+
478
+ def _handle_stream_request(
479
+ self,
480
+ prompt: Union[List[str], str],
481
+ stop: Optional[List[str]],
482
+ run_manager: Optional[CallbackManagerForLLMRun],
483
+ kwargs: Dict[str, Any],
484
+ ) -> str:
485
+ """
486
+ Perform a streaming request to the LLM.
487
+
488
+ Args:
489
+ prompt: The prompt to generate from.
490
+ stop: Stop words to use when generating. Model output is cut off at the
491
+ first occurrence of any of the stop substrings.
492
+ run_manager: Callback manager for the run.
493
+ **kwargs: Additional keyword arguments. directly passed
494
+ to the sambastudio model in API call.
495
+
496
+ Returns:
497
+ The model output as a string.
498
+ """
499
+ completion = ''
500
+ for chunk in self._stream(prompt=prompt, stop=stop, run_manager=run_manager, **kwargs):
501
+ completion += chunk.text
502
+ return completion
503
+
504
+ def _call(
505
+ self,
506
+ prompt: Union[List[str], str],
507
+ stop: Optional[List[str]] = None,
508
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
509
+ **kwargs: Any,
510
+ ) -> str:
511
+ """Call out to Sambanova's complete endpoint.
512
+
513
+ Args:
514
+ prompt: The prompt to pass into the model.
515
+ stop: Optional list of stop words to use when generating.
516
+
517
+ Returns:
518
+ The string generated by the model.
519
+ """
520
+ if stop is not None:
521
+ raise Exception('stop not implemented')
522
+ try:
523
+ if self.streaming:
524
+ return self._handle_stream_request(prompt, stop, run_manager, kwargs)
525
+ return self._handle_completion_requests(prompt, stop)
526
+ except Exception as e:
527
+ # Handle any errors raised by the inference endpoint
528
+ raise ValueError(f'Error raised by the inference endpoint: {e}') from e
529
+
530
+
531
+ class SambaNovaCloud(LLM):
532
+ """
533
+ SambaNova Cloud large language models.
534
+
535
+ To use, you should have the environment variables
536
+ ``SAMBANOVA_URL`` set with your SambaNova Cloud URL.
537
+ ``SAMBANOVA_API_KEY`` set with your SambaNova Cloud API Key.
538
+
539
+ http://cloud.sambanova.ai/
540
+
541
+ Example:
542
+ .. code-block:: python
543
+
544
+ SambaNovaCloud(
545
+ sambanova_url = SambaNova cloud endpoint URL,
546
+ sambanova_api_key = set with your SambaNova cloud API key,
547
+ max_tokens = max number of tokens to generate
548
+ stop_tokens = list of stop tokens
549
+ model = model name
550
+ )
551
+ """
552
+
553
+ sambanova_url: str = ''
554
+ """SambaNova Cloud Url"""
555
+
556
+ sambanova_api_key: str = ''
557
+ """SambaNova Cloud api key"""
558
+
559
+ max_tokens: int = 4000
560
+ """max tokens to generate"""
561
+
562
+ stop_tokens: list = ['<|eot_id|>']
563
+ """Stop tokens"""
564
+
565
+ model: str = 'llama3-8b'
566
+ """LLM model expert to use"""
567
+
568
+ temperature: float = 0.0
569
+ """model temperature"""
570
+
571
+ top_p: float = 0.0
572
+ """model top p"""
573
+
574
+ top_k: int = 1
575
+ """model top k"""
576
+
577
+ stream_api: bool = True
578
+ """use stream api"""
579
+
580
+ stream_options: dict = {'include_usage': True}
581
+ """stream options, include usage to get generation metrics"""
582
+
583
+ class Config:
584
+ """Configuration for this pydantic object."""
585
+
586
+ extra = 'forbid'#Extra.forbid
587
+
588
+ @classmethod
589
+ def is_lc_serializable(cls) -> bool:
590
+ return True
591
+
592
+ @property
593
+ def _identifying_params(self) -> Dict[str, Any]:
594
+ """Get the identifying parameters."""
595
+ return {
596
+ 'model': self.model,
597
+ 'max_tokens': self.max_tokens,
598
+ 'stop': self.stop_tokens,
599
+ 'temperature': self.temperature,
600
+ 'top_p': self.top_p,
601
+ 'top_k': self.top_k,
602
+ }
603
+
604
+ def invoke(
605
+ self,
606
+ input: LanguageModelInput,
607
+ config: Optional[RunnableConfig] = None,
608
+ *,
609
+ stop: Optional[List[str]] = None,
610
+ **kwargs: Any,
611
+ ) -> str:
612
+ config = ensure_config(config)
613
+
614
+ print('Invoking SambaNovaCloud with input:', input)
615
+ response = self.generate_prompt(
616
+ [self._convert_input(input)],
617
+ stop=stop,
618
+ callbacks=config.get("callbacks"),
619
+ tags=config.get("tags"),
620
+ metadata=config.get("metadata"),
621
+ run_name=config.get("run_name"),
622
+ run_id=config.pop("run_id", None),
623
+ **kwargs,
624
+ )
625
+ run_infos = response.run
626
+
627
+ if len(run_infos) > 1:
628
+ raise NotImplementedError('Multiple runs not supported')
629
+
630
+ run_id = run_infos[0].run_id
631
+
632
+ #print('Raw response:', response.run)
633
+ #print('Run ID:', run_id)
634
+ #print(USAGE_TRACKER)
635
+ #if run_id in USAGE_TRACKER:
636
+ # print('Usage:', USAGE_TRACKER[run_id])
637
+ #return response
638
+ return (
639
+ response
640
+ .generations[0][0]
641
+ .text
642
+ )
643
+
644
+ @property
645
+ def _llm_type(self) -> str:
646
+ """Return type of llm."""
647
+ return 'SambaNova Cloud'
648
+
649
+ @pre_init
650
+ def validate_environment(cls, values: Dict) -> Dict:
651
+ """Validate that api key and python package exists in environment."""
652
+ values['sambanova_url'] = get_from_dict_or_env(
653
+ values, 'sambanova_url', 'SAMBANOVA_URL', default='https://api.sambanova.ai/v1/chat/completions'
654
+ )
655
+ values['sambanova_api_key'] = get_from_dict_or_env(values, 'sambanova_api_key', 'SAMBANOVA_API_KEY')
656
+ return values
657
+
658
+ def _handle_nlp_predict_stream(
659
+ self,
660
+ prompt: Union[List[str], str],
661
+ stop: List[str],
662
+ ) -> Iterator[GenerationChunk]:
663
+ """
664
+ Perform a streaming request to the LLM.
665
+
666
+ Args:
667
+ prompt: The prompt to use for the prediction.
668
+ stop: list of stop tokens
669
+
670
+ Returns:
671
+ An iterator of GenerationChunks.
672
+ """
673
+ try:
674
+ import sseclient
675
+ except ImportError:
676
+ raise ImportError('could not import sseclient library. Please install it with `pip install sseclient-py`.')
677
+ try:
678
+ formatted_prompt = json.loads(prompt)
679
+ except:
680
+ formatted_prompt = [{'role': 'user', 'content': prompt}]
681
+
682
+ http_session = requests.Session()
683
+ if not stop:
684
+ stop = self.stop_tokens
685
+ data = {
686
+ 'messages': formatted_prompt,
687
+ 'max_tokens': self.max_tokens,
688
+ 'stop': stop,
689
+ 'model': self.model,
690
+ 'temperature': self.temperature,
691
+ 'top_p': self.top_p,
692
+ 'top_k': self.top_k,
693
+ 'stream': self.stream_api,
694
+ 'stream_options': self.stream_options,
695
+ }
696
+ # Streaming output
697
+ response = http_session.post(
698
+ self.sambanova_url,
699
+ headers={'Authorization': f'Bearer {self.sambanova_api_key}', 'Content-Type': 'application/json'},
700
+ json=data,
701
+ stream=True,
702
+ )
703
+
704
+ client = sseclient.SSEClient(response)
705
+ close_conn = False
706
+
707
+ print('Response:', response)
708
+
709
+ if response.status_code != 200:
710
+ raise RuntimeError(
711
+ f'Sambanova /complete call failed with status code ' f'{response.status_code}.' f'{response.text}.'
712
+ )
713
+
714
+ for event in client.events():
715
+ if event.event == 'error_event':
716
+ close_conn = True
717
+ #print('Event:', event.data)
718
+ chunk = {
719
+ 'event': event.event,
720
+ 'data': event.data,
721
+ 'status_code': response.status_code,
722
+ }
723
+
724
+ if chunk.get('error'):
725
+ raise RuntimeError(
726
+ f"Sambanova /complete call failed with status code " f"{chunk['status_code']}." f"{chunk}."
727
+ )
728
+
729
+ try:
730
+ # check if the response is a final event in that case event data response is '[DONE]'
731
+ #if 'usage' in chunk['data']:
732
+ # usage = json.loads(chunk['data'])
733
+ # print('Usage:', usage)
734
+ if chunk['data'] != '[DONE]':
735
+ data = json.loads(chunk['data'])
736
+ if data.get('error'):
737
+ raise RuntimeError(
738
+ f"Sambanova /complete call failed with status code " f"{chunk['status_code']}." f"{chunk}."
739
+ )
740
+ # check if the response is a final response with usage stats (not includes content)
741
+ if data.get('usage') is None:
742
+ # check is not "end of text" response
743
+ if data['choices'][0]['finish_reason'] is None:
744
+ text = data['choices'][0]['delta']['content']
745
+ generated_chunk = GenerationChunk(text=text)
746
+ yield generated_chunk
747
+ else:
748
+ #if data['id'] not in USAGE_TRACKER:
749
+ # USAGE_TRACKER[data['id']] = []
750
+ #USAGE_TRACKER[data['id']].append(data['usage'])
751
+ append_to_usage_tracker(data['usage'])
752
+ #print(f'Usage for id {data["id"]}:', data['usage'])
753
+ except Exception as e:
754
+ raise Exception(f'Error getting content chunk raw streamed response: {chunk}')
755
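For reference, a sketch of the streamed payloads the loop above indexes into; the values are illustrative, and the field names mirror what the code accesses.

    # Content chunk: no 'usage' key, the delta carries the next piece of text
    {'choices': [{'delta': {'content': 'Hello'}, 'finish_reason': None}]}
    # Final usage chunk: no content, only aggregate token counts (forwarded to append_to_usage_tracker)
    {'choices': [], 'usage': {'prompt_tokens': 12, 'completion_tokens': 34}}
    # The stream then terminates with a literal '[DONE]' data event.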
+
+     def _stream(
+         self,
+         prompt: Union[List[str], str],
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> Iterator[GenerationChunk]:
+         """Call out to Sambanova's complete endpoint.
+
+         Args:
+             prompt: The prompt to pass into the model.
+             stop: Optional list of stop words to use when generating.
+
+         Yields:
+             The GenerationChunks produced by the model.
+         """
+         try:
+             for chunk in self._handle_nlp_predict_stream(prompt, stop):
+                 if run_manager:
+                     run_manager.on_llm_new_token(chunk.text)
+                 yield chunk
+         except Exception as e:
+             # Handle any errors raised by the inference endpoint
+             raise ValueError(f'Error raised by the inference endpoint: {e}') from e
+
+     def _handle_stream_request(
+         self,
+         prompt: Union[List[str], str],
+         stop: Optional[List[str]],
+         run_manager: Optional[CallbackManagerForLLMRun],
+         kwargs: Dict[str, Any],
+     ) -> str:
+         """
+         Perform a streaming request to the LLM.
+
+         Args:
+             prompt: The prompt to generate from.
+             stop: Stop words to use when generating. Model output is cut off at the
+                 first occurrence of any of the stop substrings.
+             run_manager: Callback manager for the run.
+             kwargs: Additional keyword arguments passed directly
+                 to the SambaNova Cloud API call.
+
+         Returns:
+             The model output as a string.
+         """
+         completion = ''
+         for chunk in self._stream(prompt=prompt, stop=stop, run_manager=run_manager, **kwargs):
+             completion += chunk.text
+         return completion
+
+     def _call(
+         self,
+         prompt: Union[List[str], str],
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> str:
+         """Call out to Sambanova's complete endpoint.
+
+         Args:
+             prompt: The prompt to pass into the model.
+             stop: Optional list of stop words to use when generating.
+
+         Returns:
+             The string generated by the model.
+         """
+         try:
+             return self._handle_stream_request(prompt, stop, run_manager, kwargs)
+         except Exception as e:
+             # Handle any errors raised by the inference endpoint
+             raise ValueError(f'Error raised by the inference endpoint: {e}') from e
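One detail worth noting: the streaming path accepts either a plain string (wrapped as a single user message) or a JSON-encoded message list. A minimal sketch of both forms, assuming `llm` is an instance of the class defined above:

    llm._call('What is the capital of France?')
    llm._call(json.dumps([
        {'role': 'system', 'content': 'You are a terse assistant.'},
        {'role': 'user', 'content': 'What is the capital of France?'},
    ]))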
toolformers/sambanova/utils.py ADDED
@@ -0,0 +1,32 @@
+ from contextlib import contextmanager
+
+ USAGE_TRACKER = None
+
+ @contextmanager
+ def usage_tracker():
+     global USAGE_TRACKER
+     assert USAGE_TRACKER is None
+     USAGE_TRACKER = []
+     try:
+         yield
+     finally:
+         USAGE_TRACKER = None
+
+ def get_total_usage():
+     global USAGE_TRACKER
+
+     prompt_tokens = 0
+     completion_tokens = 0
+
+     for usage in USAGE_TRACKER:
+         prompt_tokens += usage['prompt_tokens']
+         completion_tokens += usage['completion_tokens']
+
+     return {
+         'prompt_tokens': prompt_tokens,
+         'completion_tokens': completion_tokens
+     }
+
+ def append_to_usage_tracker(usage):
+     global USAGE_TRACKER
+     USAGE_TRACKER.append(usage)
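A quick sketch of the intended call pattern; the model call inside the block is an assumption, but anything that reaches _handle_nlp_predict_stream records its usage here. Note that get_total_usage() has to run inside the block, since the tracker is reset to None on exit.

    from toolformers.sambanova.utils import usage_tracker, get_total_usage

    with usage_tracker():
        # ... invoke the SambaNova-backed model here ...
        totals = get_total_usage()  # {'prompt_tokens': ..., 'completion_tokens': ...}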
utils.py ADDED
@@ -0,0 +1,80 @@
+ import base64
+ import hashlib
+
+ from functools import wraps
+ from typing import Callable, Any
+ from inspect import signature, Parameter, Signature
+
+
+ def extract_substring(text, start_tag, end_tag):
+     start_position = text.lower().find(start_tag.lower())
+     end_position = text.lower().find(end_tag.lower())
+
+     if start_position == -1 or end_position == -1:
+         return None
+
+     return text[start_position + len(start_tag):end_position].strip()
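A small illustration of the helper above (the tags and text are illustrative): tags are matched case-insensitively, and the stripped text between the first occurrences is returned, or None if either tag is missing.

    extract_substring('Reply: <ANSWER> 42 </ANSWER>', '<answer>', '</answer>')  # -> '42'
    extract_substring('no tags here', '<answer>', '</answer>')                  # -> None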
+
+ def compute_hash(s):
+     # Hash a string using SHA-1 and return the base64 encoded result
+
+     m = hashlib.sha1()
+     m.update(s.encode())
+
+     b = m.digest()
+
+     return base64.b64encode(b).decode('ascii')
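For illustration, the result is the 28-character base64 encoding of the 20-byte SHA-1 digest:

    compute_hash('hello')  # -> 'qvTGHdzF6KLavt4PO0gs2a6pQ00='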
+
+ def add_params_and_annotations(name: str, description: str, params: dict, return_type: Any):
+     """
+     A decorator to add parameters and a return type annotation to a function.
+
+     :param name: The name to assign to the wrapped function.
+     :param description: The description used as the base of the generated docstring.
+     :param params: A dictionary mapping parameter names to (type, description) tuples.
+     :param return_type: The return type to add to the function's signature.
+     """
+     def decorator(func: Callable):
+         # Create new parameters based on the provided params dict
+         new_params = [
+             Parameter(param_name, Parameter.POSITIONAL_OR_KEYWORD, annotation=type_)
+             for param_name, (type_, _) in params.items()
+         ]
+
+         # Get the existing signature and parameters
+         original_sig = signature(func)
+         original_params = list(original_sig.parameters.values())
+
+         # Combine new parameters with the existing ones
+         combined_params = new_params  # + original_params
+
+         # Create a new signature with updated parameters and return annotation
+         new_sig = Signature(parameters=combined_params, return_annotation=return_type)
+
+         # Define the wrapper function
+         @wraps(func)
+         def wrapper(*args, **kwargs):
+             return func(*args, **kwargs)
+
+         docstring = description
+
+         if len(params) > 0:
+             docstring += '\n\nArgs:'
+             for param_name, (type_, param_description) in params.items():
+                 docstring += f'\n {param_name}: {param_description}'
+
+         docstring += f'\n\nReturns:\n {return_type.__name__}'
+
+         print('Input params:', params)
+
+         # Set the new signature on the wrapper
+         wrapper.__name__ = name
+         wrapper.__signature__ = new_sig
+         wrapper.__annotations__.update({ k: v[0] for k, v in params.items() })
+         wrapper.__annotations__['return'] = return_type
+         # wrapper.__annotations__['name'] = str
+         wrapper.__doc__ = docstring
+
+         print(docstring)
+
+         return wrapper
+     return decorator
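A brief sketch of how the decorator above might be applied; the function name, parameter, and body are illustrative assumptions.

    @add_params_and_annotations(
        name='greet',
        description='Greet a person by name.',
        params={'person': (str, 'The name of the person to greet.')},
        return_type=str,
    )
    def greet_impl(person):
        return f'Hello, {person}!'

    # The wrapper now exposes the synthetic metadata:
    # greet_impl.__name__ == 'greet', its signature reads (person: str) -> str,
    # and its __doc__ contains the generated Args/Returns text.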