ff_li
oai
e7de191
raw
history blame
2.12 kB
import os
from typing import Dict, Iterator, List, Optional
import openai
from agent.llm.base import BaseChatModel
from typing import Dict, List, Literal, Optional, Union
class ChatAsOAI(BaseChatModel):
    """Chat model backed by the legacy (pre-1.0) ``openai.ChatCompletion`` API.

    Connection settings are read from environment variables:

    * ``OPENAI_API_BASE`` -- endpoint URL. When unset, the ``openai``
      library's own default endpoint is left in place.
    * ``OPENAI_API_KEY`` -- API key; ``'EMPTY'`` is used when unset
      (presumably for OpenAI-compatible servers that ignore the key --
      TODO confirm).
    * ``OPENAI_MODEL_NAME`` -- when set, overrides the ``model`` argument.

    NOTE: the endpoint and key are assigned to module-level ``openai``
    settings, so they affect every user of the ``openai`` module in this
    process, not only this instance.
    """

    def __init__(self, model: str):
        """Configure the ``openai`` module and choose the model name.

        Args:
            model: Model name to use unless ``OPENAI_MODEL_NAME`` is set.
        """
        super().__init__()
        # Only override the base URL when one is actually configured:
        # assigning ``None`` (what ``os.getenv`` returns for an unset
        # variable) would replace the library's default endpoint and break
        # every request.
        api_base = os.getenv('OPENAI_API_BASE')
        if api_base:
            openai.api_base = api_base
        openai.api_key = os.getenv('OPENAI_API_KEY', 'EMPTY')
        self.model = os.getenv('OPENAI_MODEL_NAME', model)

    def _chat_stream(
        self,
        messages: List[Dict],
        stop: Optional[List[str]] = None,
    ) -> Iterator[str]:
        """Stream a completion for ``messages``, yielding text fragments.

        Args:
            messages: Chat history forwarded unchanged to
                ``openai.ChatCompletion.create``.
            stop: Optional stop sequences forwarded to the API.

        Yields:
            The ``content`` of each streamed delta, in arrival order. Chunks
            that have no choices, or whose delta carries no content (missing
            or ``None``), are skipped.
        """
        response = openai.ChatCompletion.create(model=self.model,
                                                messages=messages,
                                                stop=stop,
                                                stream=True)
        # TODO: error handling -- client exceptions currently propagate.
        for chunk in response:
            # A chunk may arrive with an empty ``choices`` list; indexing it
            # would raise IndexError, so skip it.
            if not chunk.choices:
                continue
            content = getattr(chunk.choices[0].delta, 'content', None)
            # A delta may lack ``content`` or carry ``None``; yielding None
            # would break callers that concatenate the fragments.
            if content is not None:
                yield content

    def _chat_no_stream(
        self,
        messages: List[Dict],
        stop: Optional[List[str]] = None,
    ) -> str:
        """Request a single, non-streamed completion for ``messages``.

        Args:
            messages: Chat history forwarded unchanged to
                ``openai.ChatCompletion.create``.
            stop: Optional stop sequences forwarded to the API.

        Returns:
            ``content`` of the first choice's message. NOTE(review): this may
            be ``None`` if the model replies without text (e.g. a function
            call) -- confirm callers handle that.
        """
        response = openai.ChatCompletion.create(model=self.model,
                                                messages=messages,
                                                stop=stop,
                                                stream=False)
        # TODO: error handling -- client exceptions currently propagate.
        return response.choices[0].message.content

    def chat_with_functions(self,
                            messages: List[Dict],
                            functions: Optional[List[Dict]] = None) -> Dict:
        """Request a completion, optionally offering function definitions.

        Args:
            messages: Chat history forwarded unchanged to
                ``openai.ChatCompletion.create``.
            functions: Function definitions to expose to the model. A
                missing or empty list is omitted from the request entirely
                rather than being sent as ``functions=[]``.

        Returns:
            ``response.choices[0].message`` exactly as returned by the client
            (presumably a dict-like object with ``content`` and, when the
            model calls a function, ``function_call`` -- verify against the
            installed openai version).
        """
        request = {'model': self.model, 'messages': messages}
        if functions:
            request['functions'] = functions
        response = openai.ChatCompletion.create(**request)
        # TODO: error handling -- client exceptions currently propagate.
        return response.choices[0].message