venkat-srinivasan-nexusflow commited on
Commit
fb065be
1 Parent(s): f463518

Create weather_with_chat.py

Browse files
Files changed (1) hide show
  1. example/weather_with_chat.py +246 -0
example/weather_with_chat.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import json
3
+ from typing import List, Dict, Any, Optional
4
+ from openai import OpenAI
5
+ """
6
+ EXAMPLE OUTPUT:
7
+
8
+ ****************************************
9
+ RUNNING QUERY: What's the weather for Paris, TX in fahrenheit?
10
+
11
+ Agent Issued Step 1
12
+ ----------------------------------------
13
+
14
+ Agent Issued Step 2
15
+ ----------------------------------------
16
+
17
+ Agent Issued Step 3
18
+ ----------------------------------------
19
+ AGENT MESSAGE: The current weather in Paris, TX is 85 degrees fahrenheit. It is partly cloudy, with highs in the 90s.
20
+ Conversation Complete
21
+
22
+
23
+ ****************************************
24
+ RUNNING QUERY: Who won the most recent PGA?
25
+
26
+ Agent Issued Step 1
27
+ ----------------------------------------
28
+
29
+ Agent Issued Step 2
30
+ ----------------------------------------
31
+ AGENT MESSAGE: I'm sorry, but I don't have the ability to provide sports information. I can help you with weather and location data. Is there anything else I can assist you with?
32
+ Conversation Complete
33
+ """
34
+
35
+ @dataclass
36
+ class WeatherConfig:
37
+ """Configuration for OpenAI and API settings"""
38
+ api_key: str = "" # The VLLM api_key
39
+ api_base: str = "" # The VLLM api_base URL
40
+ model: Optional[str] = None
41
+ max_steps: int = 5
42
+
43
+ class WeatherTools:
44
+ """Collection of available tools/functions for the weather agent"""
45
+
46
+ @staticmethod
47
+ def get_current_weather(latitude: List[float], longitude: List[float], unit: str) -> str:
48
+ """Get weather for given coordinates"""
49
+ # We are mocking the weather here, but in the real world, you will submit a request here.
50
+ return f"The weather is 85 degrees {unit}. It is partly cloudy, with highs in the 90's."
51
+
52
+ @staticmethod
53
+ def get_geo_coordinates(city: str, state: str) -> str:
54
+ """Get coordinates for a given city"""
55
+ coordinates = {
56
+ "Dallas": {"TX": (32.7767, -96.7970)},
57
+ "San Francisco": {"CA": (37.7749, -122.4194)},
58
+ "Paris": {"TX": (33.6609, 95.5555)}
59
+ }
60
+ lat, lon = coordinates.get(city, {}).get(state, (0, 0))
61
+ # We are mocking the weather here, but in the real world, you will submit a request here.
62
+ return f"The coordinates for {city}, {state} are: latitude {lat}, longitude {lon}"
63
+
64
+ @staticmethod
65
+ def no_relevant_function(user_query_span : str) -> str:
66
+ return "No relevant function for your request was found. We will stop here."
67
+
68
+ @staticmethod
69
+ def chat(chat_string : str):
70
+ print ("AGENT MESSAGE: ", chat_string)
71
+
72
+ class ToolRegistry:
73
+ """Registry of available tools and their schemas"""
74
+
75
+ @property
76
+ def available_functions(self) -> Dict[str, callable]:
77
+ return {
78
+ "get_current_weather": WeatherTools.get_current_weather,
79
+ "get_geo_coordinates": WeatherTools.get_geo_coordinates,
80
+ "no_relevant_function" : WeatherTools.no_relevant_function,
81
+ "chat" : WeatherTools.chat
82
+ }
83
+
84
+ @property
85
+ def tool_schemas(self) -> List[Dict[str, Any]]:
86
+ return [
87
+ {
88
+ "type": "function",
89
+ "function": {
90
+ "name": "get_current_weather",
91
+ "description": "Get the current weather in a given location. Use exact coordinates.",
92
+ "parameters": {
93
+ "type": "object",
94
+ "properties": {
95
+ "latitude": {"type": "array", "description": "The latitude for the city."},
96
+ "longitude": {"type": "array", "description": "The longitude for the city."},
97
+ "unit": {
98
+ "type": "string",
99
+ "description": "The unit to fetch the temperature in",
100
+ "enum": ["celsius", "fahrenheit"]
101
+ }
102
+ },
103
+ "required": ["latitude", "longitude", "unit"]
104
+ }
105
+ }
106
+ },
107
+ {
108
+ "type": "function",
109
+ "function": {
110
+ "name": "get_geo_coordinates",
111
+ "description": "Get the latitude and longitude for a given city",
112
+ "parameters": {
113
+ "type": "object",
114
+ "properties": {
115
+ "city": {"type": "string", "description": "The city to find coordinates for"},
116
+ "state": {"type": "string", "description": "The two-letter state abbreviation"}
117
+ },
118
+ "required": ["city", "state"]
119
+ }
120
+ }
121
+ },
122
+ {
123
+ "type": "function",
124
+ "function" : {
125
+ "name": "no_relevant_function",
126
+ "description": "Call this when no other provided function can be called to answer the user query.",
127
+ "parameters": {
128
+ "type": "object",
129
+ "properties": {
130
+ "user_query_span": {
131
+ "type": "string",
132
+ "description": "The part of the user_query that cannot be answered by any other function calls."
133
+ }
134
+ },
135
+ "required": ["user_query_span"]
136
+ }
137
+ }
138
+ },
139
+ {
140
+ "type": "function",
141
+ "function": {
142
+ "name": "chat",
143
+ "description": "Call this tool when you want to chat with the user. The user won't see anything except for whatever you pass into this function.",
144
+ "parameters": {
145
+ "type": "object",
146
+ "properties": {
147
+ "chat_string": {
148
+ "type": "string",
149
+ "description": "The string to send to the user to chat back to them.",
150
+ }
151
+ },
152
+ "required": ["chat_string"],
153
+ },
154
+ },
155
+ },
156
+ ]
157
+
158
+ class WeatherAgent:
159
+ """Main agent class that handles the conversation and tool execution"""
160
+
161
+ def __init__(self, config: WeatherConfig):
162
+ self.config = config
163
+ self.client = OpenAI(api_key=config.api_key, base_url=config.api_base)
164
+ self.tools = ToolRegistry()
165
+ self.messages = []
166
+
167
+ if not config.model:
168
+ models = self.client.models.list()
169
+ self.config.model = models.data[0].id
170
+
171
+ def _serialize_tool_call(self, tool_call) -> Dict[str, Any]:
172
+ """Convert tool call to serializable format"""
173
+ return {
174
+ "id": tool_call.id,
175
+ "type": tool_call.type,
176
+ "function": {
177
+ "name": tool_call.function.name,
178
+ "arguments": tool_call.function.arguments
179
+ }
180
+ }
181
+
182
+ def process_tool_calls(self, message) -> None:
183
+ """Process and execute tool calls from assistant"""
184
+ for tool_call in message.tool_calls:
185
+ function_name = tool_call.function.name
186
+ function_args = json.loads(tool_call.function.arguments)
187
+
188
+ function_response = self.tools.available_functions[function_name](**function_args)
189
+
190
+ self.messages.append({
191
+ "role": "tool",
192
+ "content": json.dumps(function_response),
193
+ "tool_call_id": tool_call.id,
194
+ "name": function_name
195
+ })
196
+
197
+ def run_conversation(self, initial_query: str) -> None:
198
+ """Run the main conversation loop"""
199
+ self.messages = [
200
+ {"role" : "system", "content" : "Make sure to use the chat() function to provide the final answer to the user."},
201
+ {"role": "user", "content": initial_query}]
202
+
203
+ print ("\n" * 5)
204
+ print ("*" * 40)
205
+ print (f"RUNNING QUERY: {initial_query}")
206
+
207
+ for step in range(self.config.max_steps):
208
+
209
+ response = self.client.chat.completions.create(
210
+ messages=self.messages,
211
+ model=self.config.model,
212
+ tools=self.tools.tool_schemas,
213
+ temperature=0.0,
214
+ )
215
+
216
+ message = response.choices[0].message
217
+
218
+ if not message.tool_calls:
219
+ print("Conversation Complete")
220
+ break
221
+
222
+ print(f"\nAgent Issued Step {step + 1}")
223
+ print("-" * 40)
224
+
225
+ self.messages.append({
226
+ "role": "assistant",
227
+ "content": json.dumps(message.content),
228
+ "tool_calls": [self._serialize_tool_call(tc) for tc in message.tool_calls]
229
+ })
230
+
231
+ self.process_tool_calls(message)
232
+
233
+ if step >= self.config.max_steps - 1:
234
+ print("Maximum steps reached")
235
+
236
+ def main():
237
+ # Example usage
238
+ config = WeatherConfig()
239
+ agent = WeatherAgent(config)
240
+ agent.run_conversation("What's the weather for Paris, TX in fahrenheit?")
241
+
242
+ # Example OOD usage
243
+ agent.run_conversation("Who won the most recent PGA?")
244
+
245
+ if __name__ == "__main__":
246
+ main()