lucy1118 committed (verified) · Commit 405fd0e · 1 parent: 23bf47f

Upload openai.py

Files changed (1):
  1. src/utils/openai.py (added: +338, -0)
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import aiohttp
import base64
import io
import json

from typing import AsyncGenerator, List, Literal

from loguru import logger
from PIL import Image

from pipecat.frames.frames import (
    AudioRawFrame,
    ErrorFrame,
    Frame,
    LLMFullResponseEndFrame,
    LLMFullResponseStartFrame,
    LLMMessagesFrame,
    LLMResponseEndFrame,
    LLMResponseStartFrame,
    TextFrame,
    URLImageRawFrame,
    VisionImageRawFrame
)
from pipecat.processors.aggregators.openai_llm_context import (
    OpenAILLMContext,
    OpenAILLMContextFrame
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import (
    ImageGenService,
    LLMService,
    TTSService
)

try:
    from openai import AsyncOpenAI, AsyncStream, BadRequestError
    from openai.types.chat import (
        ChatCompletionChunk,
        ChatCompletionFunctionMessageParam,
        ChatCompletionMessageParam,
        ChatCompletionToolParam
    )
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error(
        "In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set the `OPENAI_API_KEY` environment variable.")
    raise Exception(f"Missing module: {e}")


class OpenAIUnhandledFunctionException(Exception):
    pass


class BaseOpenAILLMService(LLMService):
    """This is the base class for all services that use the AsyncOpenAI client.

    This service consumes OpenAILLMContextFrame frames, which contain a reference
    to an OpenAILLMContext object. The OpenAILLMContext defines the context
    sent to the LLM for a completion: user, assistant, and system messages,
    as well as the tool choices and the tools used when requesting function
    calls from the LLM.
    """

    def __init__(self, *, model: str, api_key=None, base_url=None, **kwargs):
        super().__init__(**kwargs)
        self._model: str = model
        self._client = self.create_client(api_key=api_key, base_url=base_url, **kwargs)

    def create_client(self, api_key=None, base_url=None, **kwargs):
        return AsyncOpenAI(api_key=api_key, base_url=base_url)

    def can_generate_metrics(self) -> bool:
        return True

    async def get_chat_completions(
            self,
            context: OpenAILLMContext,
            messages: List[ChatCompletionMessageParam]) -> AsyncStream[ChatCompletionChunk]:
        chunks = await self._client.chat.completions.create(
            model=self._model,
            stream=True,
            messages=messages,
            tools=context.tools,
            tool_choice=context.tool_choice,
        )
        return chunks

    async def _stream_chat_completions(
            self, context: OpenAILLMContext) -> AsyncStream[ChatCompletionChunk]:
        logger.debug(f"Generating chat: {context.get_messages_json()}")

        messages: List[ChatCompletionMessageParam] = context.get_messages()

        # base64-encode any images
        for message in messages:
            if message.get("mime_type") == "image/jpeg":
                encoded_image = base64.b64encode(message["data"].getvalue()).decode("utf-8")
                text = message["content"]
                message["content"] = [
                    {"type": "text", "text": text},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
                ]
                del message["data"]
                del message["mime_type"]

        chunks = await self.get_chat_completions(context, messages)

        return chunks

    async def _process_context(self, context: OpenAILLMContext):
        function_name = ""
        arguments = ""
        tool_call_id = ""

        await self.start_ttfb_metrics()

        chunk_stream: AsyncStream[ChatCompletionChunk] = (
            await self._stream_chat_completions(context)
        )

        async for chunk in chunk_stream:
            if len(chunk.choices) == 0:
                continue

            await self.stop_ttfb_metrics()

            if chunk.choices[0].delta.tool_calls:
                # We're streaming the LLM response to enable the fastest response times.
                # For text, we just yield each chunk as we receive it and count on consumers
                # to do whatever coalescing they need (e.g. to pass full sentences to TTS).
                #
                # If the LLM response is a function call, we do some coalescing here.
                # If the response contains a function name, we yield a frame to tell consumers
                # that they can start preparing to call the function with that name.
                # We accumulate the argument fragments for the rest of the streamed response;
                # when the response is done, we package up the function name and the complete
                # arguments and handle the call.

                tool_call = chunk.choices[0].delta.tool_calls[0]
                if tool_call.function and tool_call.function.name:
                    function_name += tool_call.function.name
                    tool_call_id = tool_call.id
                    await self.call_start_function(function_name)
                if tool_call.function and tool_call.function.arguments:
                    # Keep iterating through the response to collect all the argument fragments
                    arguments += tool_call.function.arguments
            elif chunk.choices[0].delta.content:
                await self.push_frame(LLMResponseStartFrame())
                await self.push_frame(TextFrame(chunk.choices[0].delta.content))
                await self.push_frame(LLMResponseEndFrame())

        # If we got a function name and arguments, check whether it's a function with
        # a registered handler. If so, run the registered callback, save the result to
        # the context, and re-prompt to get a chat answer. If we don't have a registered
        # handler, raise an exception.
        if function_name and arguments:
            if self.has_function(function_name):
                await self._handle_function_call(context, tool_call_id, function_name, arguments)
            else:
                raise OpenAIUnhandledFunctionException(
                    f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function.")

    async def _handle_function_call(
        self,
        context,
        tool_call_id,
        function_name,
        arguments
    ):
        arguments = json.loads(arguments)
        result = await self.call_function(function_name, arguments)
        arguments = json.dumps(arguments)
        if isinstance(result, (str, dict)):
            # Handle it in "full magic mode": append the tool call and its result
            # to the context, then re-prompt the LLM.
            tool_call = ChatCompletionFunctionMessageParam({
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": tool_call_id,
                        "function": {
                            "arguments": arguments,
                            "name": function_name
                        },
                        "type": "function"
                    }
                ]
            })
            context.add_message(tool_call)
            if isinstance(result, dict):
                result = json.dumps(result)
            tool_result = ChatCompletionToolParam({
                "tool_call_id": tool_call_id,
                "role": "tool",
                "content": result
            })
            context.add_message(tool_result)
            # Re-prompt to get a human-readable answer
            await self._process_context(context)
        elif isinstance(result, list):
            # Reduced magic: the callback returned ready-made messages, so append
            # them to the context and re-prompt.
            for msg in result:
                context.add_message(msg)
            await self._process_context(context)
        elif isinstance(result, type(None)):
            pass
        else:
            raise TypeError(f"Unknown return type from function callback: {type(result)}")

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        context = None
        if isinstance(frame, OpenAILLMContextFrame):
            context: OpenAILLMContext = frame.context
        elif isinstance(frame, LLMMessagesFrame):
            context = OpenAILLMContext.from_messages(frame.messages)
        elif isinstance(frame, VisionImageRawFrame):
            context = OpenAILLMContext.from_image_frame(frame)
        else:
            await self.push_frame(frame, direction)

        if context:
            await self.push_frame(LLMFullResponseStartFrame())
            await self.start_processing_metrics()
            await self._process_context(context)
            await self.stop_processing_metrics()
            await self.push_frame(LLMFullResponseEndFrame())


class OpenAILLMService(BaseOpenAILLMService):

    def __init__(self, *, model: str = "gpt-4o", **kwargs):
        super().__init__(model=model, **kwargs)
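The function-calling flow above is easiest to see with a concrete setup. Below is a minimal sketch assuming APIs this file references but does not define: a `register_function(name, callback)` helper on `LLMService` (which backs `has_function()` / `call_function()`), an `OpenAILLMContext(messages=..., tools=...)` constructor, and a callback that receives a single arguments dict. All three vary across pipecat versions, so verify them against your install; `get_weather` is a hypothetical handler.

```python
import os

# Hypothetical wiring for the function-call path in _process_context() /
# _handle_function_call() above.
llm = OpenAILLMService(model="gpt-4o", api_key=os.environ["OPENAI_API_KEY"])

async def get_weather(args: dict):
    # Hypothetical handler. Returning a str or dict takes the "full magic
    # mode" branch: the tool call and its result are appended to the context
    # and the LLM is re-prompted for a natural-language answer.
    return {"city": args["city"], "conditions": "sunny", "temperature_c": 21}

# Assumption: register_function pairs the tool name with the coroutine that
# call_function() awaits; check the exact signature in your pipecat version.
llm.register_function("get_weather", get_weather)

context = OpenAILLMContext(
    messages=[{"role": "system", "content": "You are a helpful voice assistant."}],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
)
# In a pipeline, pushing OpenAILLMContextFrame(context) through this service
# emits LLMFullResponseStartFrame, streamed TextFrames, then
# LLMFullResponseEndFrame.
```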
class OpenAIImageGenService(ImageGenService):

    def __init__(
        self,
        *,
        image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
        aiohttp_session: aiohttp.ClientSession,
        api_key: str,
        model: str = "dall-e-3",
    ):
        super().__init__()
        self._model = model
        self._image_size = image_size
        self._client = AsyncOpenAI(api_key=api_key)
        self._aiohttp_session = aiohttp_session

    async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
        logger.debug(f"Generating image from prompt: {prompt}")

        image = await self._client.images.generate(
            prompt=prompt,
            model=self._model,
            n=1,
            size=self._image_size
        )

        image_url = image.data[0].url

        if not image_url:
            logger.error(f"{self} No image provided in response: {image}")
            yield ErrorFrame("Image generation failed")
            return

        # Load the image from the URL
        async with self._aiohttp_session.get(image_url) as response:
            image_stream = io.BytesIO(await response.content.read())
            image = Image.open(image_stream)
            frame = URLImageRawFrame(image_url, image.tobytes(), image.size, image.format)
            yield frame
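`OpenAIImageGenService` fetches the generated image over HTTP, which is why it takes an externally managed `aiohttp.ClientSession`. A minimal standalone sketch of the call follows, assuming it's acceptable to invoke `run_image_gen()` directly for testing (in a pipeline, `ImageGenService` drives it from incoming text):

```python
import asyncio
import os

import aiohttp

async def main():
    async with aiohttp.ClientSession() as session:
        imagegen = OpenAIImageGenService(
            image_size="1024x1024",
            aiohttp_session=session,
            api_key=os.environ["OPENAI_API_KEY"],
        )
        # Yields a URLImageRawFrame on success, an ErrorFrame on failure.
        async for frame in imagegen.run_image_gen("a watercolor lighthouse at dusk"):
            print(type(frame).__name__)

asyncio.run(main())
```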
class OpenAITTSService(TTSService):
    """This service uses the OpenAI TTS API to generate audio from text.

    The returned audio is PCM encoded at 24kHz. When using the DailyTransport,
    set the sample rate in the DailyParams accordingly:
    ```
    DailyParams(
        audio_out_enabled=True,
        audio_out_sample_rate=24_000,
    )
    ```
    """

    def __init__(
            self,
            *,
            api_key: str | None = None,
            base_url: str | None = None,
            sample_rate: int = 24_000,
            voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = "alloy",
            model: Literal["tts-1", "tts-1-hd"] = "tts-1",
            **kwargs):
        super().__init__(**kwargs)

        self._voice = voice
        self._model = model
        self.sample_rate = sample_rate
        self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)

    def can_generate_metrics(self) -> bool:
        return True

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        logger.debug(f"Generating TTS: [{text}]")

        try:
            await self.start_ttfb_metrics()

            async with self._client.audio.speech.with_streaming_response.create(
                input=text,
                model=self._model,
                voice=self._voice,
                response_format="pcm",
            ) as r:
                if r.status_code != 200:
                    error = await r.text()
                    logger.error(
                        f"{self} error getting audio (status: {r.status_code}, error: {error})")
                    yield ErrorFrame(f"Error getting audio (status: {r.status_code}, error: {error})")
                    return
                async for chunk in r.iter_bytes(8192):
                    if len(chunk) > 0:
                        await self.stop_ttfb_metrics()
                        frame = AudioRawFrame(chunk, self.sample_rate, 1)
                        yield frame
        except BadRequestError as e:
            logger.exception(f"{self} error generating TTS: {e}")
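As the docstring notes, `run_tts()` yields raw PCM at 24 kHz (mono, 16-bit little-endian, per the OpenAI `pcm` response format). A minimal standalone sketch that collects the audio and writes it to disk, assuming `AudioRawFrame` exposes the chunk bytes as `.audio` (check your pipecat version) and that calling `run_tts()` outside a pipeline is acceptable for testing:

```python
import asyncio
import os

async def main():
    tts = OpenAITTSService(api_key=os.environ["OPENAI_API_KEY"], voice="nova")
    pcm = bytearray()
    async for frame in tts.run_tts("Hello from OpenAI text to speech."):
        if isinstance(frame, AudioRawFrame):
            pcm.extend(frame.audio)  # assumption: .audio holds the raw bytes
    # Raw PCM, 24 kHz, mono, 16-bit. Play with, e.g.:
    #   ffplay -f s16le -ar 24000 -ac 1 out.pcm
    with open("out.pcm", "wb") as f:
        f.write(bytes(pcm))

asyncio.run(main())
```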