Upload openai.py
src/utils/openai.py
ADDED
@@ -0,0 +1,338 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import aiohttp
import base64
import io
import json

from typing import AsyncGenerator, List, Literal

from loguru import logger
from PIL import Image

from pipecat.frames.frames import (
    AudioRawFrame,
    ErrorFrame,
    Frame,
    LLMFullResponseEndFrame,
    LLMFullResponseStartFrame,
    LLMMessagesFrame,
    LLMResponseEndFrame,
    LLMResponseStartFrame,
    TextFrame,
    URLImageRawFrame,
    VisionImageRawFrame
)
from pipecat.processors.aggregators.openai_llm_context import (
    OpenAILLMContext,
    OpenAILLMContextFrame
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import (
    ImageGenService,
    LLMService,
    TTSService
)

try:
    from openai import AsyncOpenAI, AsyncStream, BadRequestError
    from openai.types.chat import (
        ChatCompletionChunk,
        ChatCompletionFunctionMessageParam,
        ChatCompletionMessageParam,
        ChatCompletionToolParam
    )
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error(
        "In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set the `OPENAI_API_KEY` environment variable.")
    raise Exception(f"Missing module: {e}")


class OpenAIUnhandledFunctionException(Exception):
    pass


class BaseOpenAILLMService(LLMService):
    """This is the base class for all services that use the AsyncOpenAI client.

    This service consumes OpenAILLMContextFrame frames, which contain a reference
    to an OpenAILLMContext object. The OpenAILLMContext defines the context sent
    to the LLM for a completion: the user, assistant and system messages, as well
    as the tools and tool choice used when requesting function calls from the LLM.
    """

    def __init__(self, *, model: str, api_key=None, base_url=None, **kwargs):
        super().__init__(**kwargs)
        self._model: str = model
        self._client = self.create_client(api_key=api_key, base_url=base_url, **kwargs)

    def create_client(self, api_key=None, base_url=None, **kwargs):
        return AsyncOpenAI(api_key=api_key, base_url=base_url)

    def can_generate_metrics(self) -> bool:
        return True

    async def get_chat_completions(
            self,
            context: OpenAILLMContext,
            messages: List[ChatCompletionMessageParam]) -> AsyncStream[ChatCompletionChunk]:
        chunks = await self._client.chat.completions.create(
            model=self._model,
            stream=True,
            messages=messages,
            tools=context.tools,
            tool_choice=context.tool_choice,
        )
        return chunks

    async def _stream_chat_completions(
            self, context: OpenAILLMContext) -> AsyncStream[ChatCompletionChunk]:
        logger.debug(f"Generating chat: {context.get_messages_json()}")

        messages: List[ChatCompletionMessageParam] = context.get_messages()

        # base64 encode any images
        for message in messages:
            if message.get("mime_type") == "image/jpeg":
                encoded_image = base64.b64encode(message["data"].getvalue()).decode("utf-8")
                text = message["content"]
                message["content"] = [
                    {"type": "text", "text": text},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
                ]
                del message["data"]
                del message["mime_type"]

        chunks = await self.get_chat_completions(context, messages)

        return chunks

    async def _process_context(self, context: OpenAILLMContext):
        function_name = ""
        arguments = ""
        tool_call_id = ""

        await self.start_ttfb_metrics()

        chunk_stream: AsyncStream[ChatCompletionChunk] = (
            await self._stream_chat_completions(context)
        )

        async for chunk in chunk_stream:
            if len(chunk.choices) == 0:
                continue

            await self.stop_ttfb_metrics()

            if chunk.choices[0].delta.tool_calls:
                # We're streaming the LLM response to enable the fastest response times.
                # For text, we just push each chunk as we receive it and count on consumers
                # to do whatever coalescing they need (e.g. to pass full sentences to TTS).
                #
                # If the LLM response is a function call, we do some coalescing here instead.
                # As soon as the response contains a function name, we signal it (via
                # call_start_function) so consumers can start preparing to call that
                # function. We then accumulate the argument fragments for the rest of the
                # streamed response, and once the stream is done we run the call with the
                # function name and the complete arguments.

                tool_call = chunk.choices[0].delta.tool_calls[0]
                if tool_call.function and tool_call.function.name:
                    function_name += tool_call.function.name
                    tool_call_id = tool_call.id
                    await self.call_start_function(function_name)
                if tool_call.function and tool_call.function.arguments:
                    # Keep iterating through the response to collect all the argument fragments
                    arguments += tool_call.function.arguments
            elif chunk.choices[0].delta.content:
                await self.push_frame(LLMResponseStartFrame())
                await self.push_frame(TextFrame(chunk.choices[0].delta.content))
                await self.push_frame(LLMResponseEndFrame())

        # If we got a function name and arguments, check to see if it's a function with
        # a registered handler. If so, run the registered callback, save the result to
        # the context, and re-prompt to get a chat answer. If we don't have a registered
        # handler, raise an exception.
        if function_name and arguments:
            if self.has_function(function_name):
                await self._handle_function_call(context, tool_call_id, function_name, arguments)
            else:
                raise OpenAIUnhandledFunctionException(
                    f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function.")

    async def _handle_function_call(
            self,
            context,
            tool_call_id,
            function_name,
            arguments
    ):
        arguments = json.loads(arguments)
        result = await self.call_function(function_name, arguments)
        arguments = json.dumps(arguments)
        if isinstance(result, (str, dict)):
            # Handle it in "full magic mode": append the tool call and its result
            # to the context, then re-prompt the LLM.
            tool_call = ChatCompletionFunctionMessageParam({
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": tool_call_id,
                        "function": {
                            "arguments": arguments,
                            "name": function_name
                        },
                        "type": "function"
                    }
                ]
            })
            context.add_message(tool_call)
            if isinstance(result, dict):
                result = json.dumps(result)
            tool_result = ChatCompletionToolParam({
                "tool_call_id": tool_call_id,
                "role": "tool",
                "content": result
            })
            context.add_message(tool_result)
            # re-prompt to get a conversational answer
            await self._process_context(context)
        elif isinstance(result, list):
            # Reduced magic: the callback returned pre-formed messages, so just
            # add them to the context and re-prompt.
            for msg in result:
                context.add_message(msg)
            await self._process_context(context)
        elif result is None:
            pass
        else:
            raise TypeError(f"Unknown return type from function callback: {type(result)}")

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        context = None
        if isinstance(frame, OpenAILLMContextFrame):
            context: OpenAILLMContext = frame.context
        elif isinstance(frame, LLMMessagesFrame):
            context = OpenAILLMContext.from_messages(frame.messages)
        elif isinstance(frame, VisionImageRawFrame):
            context = OpenAILLMContext.from_image_frame(frame)
        else:
            await self.push_frame(frame, direction)

        if context:
            await self.push_frame(LLMFullResponseStartFrame())
            await self.start_processing_metrics()
            await self._process_context(context)
            await self.stop_processing_metrics()
            await self.push_frame(LLMFullResponseEndFrame())
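

# Usage sketch (illustrative; not part of the uploaded file): the function-call
# path above assumes handlers were registered on the service beforehand. Assuming
# the LLMService base class exposes a `register_function(name, callback)` helper
# (the counterpart to the `has_function` / `call_function` / `call_start_function`
# calls used in `_process_context`), wiring up a handler might look like:
#
#     async def fetch_weather(args: dict) -> dict:
#         # Returning a dict (or str) takes the "full magic mode" branch in
#         # `_handle_function_call`: the tool call and its result are appended
#         # to the context and the LLM is re-prompted automatically.
#         return {"conditions": "sunny", "temperature_f": 72}
#
#     llm = OpenAILLMService(model="gpt-4o")
#     llm.register_function("fetch_weather", fetch_weather)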


class OpenAILLMService(BaseOpenAILLMService):

    def __init__(self, *, model: str = "gpt-4o", **kwargs):
        super().__init__(model=model, **kwargs)


class OpenAIImageGenService(ImageGenService):

    def __init__(
        self,
        *,
        image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
        aiohttp_session: aiohttp.ClientSession,
        api_key: str,
        model: str = "dall-e-3",
    ):
        super().__init__()
        self._model = model
        self._image_size = image_size
        self._client = AsyncOpenAI(api_key=api_key)
        self._aiohttp_session = aiohttp_session

    async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
        logger.debug(f"Generating image from prompt: {prompt}")

        image = await self._client.images.generate(
            prompt=prompt,
            model=self._model,
            n=1,
            size=self._image_size
        )

        image_url = image.data[0].url

        if not image_url:
            logger.error(f"{self} No image provided in response: {image}")
            yield ErrorFrame("Image generation failed")
            return

        # Load the image from the URL
        async with self._aiohttp_session.get(image_url) as response:
            image_stream = io.BytesIO(await response.content.read())
            image = Image.open(image_stream)
            frame = URLImageRawFrame(image_url, image.tobytes(), image.size, image.format)
            yield frame
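

# Usage sketch (illustrative; not part of the uploaded file): run_image_gen
# fetches the generated image back from the URL the OpenAI API returns, which
# is why the constructor requires a shared aiohttp session:
#
#     async with aiohttp.ClientSession() as session:
#         imagegen = OpenAIImageGenService(
#             image_size="1024x1024",
#             aiohttp_session=session,
#             api_key="sk-...",  # your OpenAI API key
#         )
#         async for frame in imagegen.run_image_gen("a watercolor lighthouse"):
#             ...  # a URLImageRawFrame on success, an ErrorFrame on failure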


class OpenAITTSService(TTSService):
    """This service uses the OpenAI TTS API to generate audio from text.

    The returned audio is PCM encoded at 24kHz. When using the DailyTransport,
    set the sample rate in the DailyParams accordingly:
    ```
    DailyParams(
        audio_out_enabled=True,
        audio_out_sample_rate=24_000,
    )
    ```
    """

    def __init__(
            self,
            *,
            api_key: str | None = None,
            base_url: str | None = None,
            sample_rate: int = 24_000,
            voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = "alloy",
            model: Literal["tts-1", "tts-1-hd"] = "tts-1",
            **kwargs):
        super().__init__(**kwargs)

        self._voice = voice
        self._model = model
        self.sample_rate = sample_rate
        self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)

    def can_generate_metrics(self) -> bool:
        return True

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        logger.debug(f"Generating TTS: [{text}]")

        try:
            await self.start_ttfb_metrics()

            async with self._client.audio.speech.with_streaming_response.create(
                    input=text,
                    model=self._model,
                    voice=self._voice,
                    response_format="pcm",
            ) as r:
                if r.status_code != 200:
                    error = await r.text()
                    logger.error(
                        f"{self} error getting audio (status: {r.status_code}, error: {error})")
                    yield ErrorFrame(f"Error getting audio (status: {r.status_code}, error: {error})")
                    return
                async for chunk in r.iter_bytes(8192):
                    if len(chunk) > 0:
                        await self.stop_ttfb_metrics()
                        frame = AudioRawFrame(chunk, self.sample_rate, 1)
                        yield frame
        except BadRequestError as e:
            logger.exception(f"{self} error generating TTS: {e}")
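

# Usage sketch (illustrative; not part of the uploaded file): run_tts yields raw
# 24kHz mono PCM in 8KB chunks, so downstream audio transports must be configured
# for that sample rate (see the DailyParams note in the class docstring):
#
#     tts = OpenAITTSService(api_key="sk-...", voice="nova")
#     async for frame in tts.run_tts("Hello there!"):
#         ...  # each frame is an AudioRawFrame(<pcm bytes>, 24000, 1)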