|
from application.llm.base import BaseLLM |
|
from application.core.settings import settings |
|
|
|
class OpenAILLM(BaseLLM):
    """LLM backend backed by the OpenAI Chat Completions API.

    Holds an ``openai.OpenAI`` client and exposes a blocking ``gen`` method
    plus a generator-based ``gen_stream`` method.
    """

    def __init__(self, api_key):
        """Create the underlying ``openai.OpenAI`` client.

        Args:
            api_key: API key passed straight to ``openai.OpenAI``.
        """
        # Imported here rather than at module level so the ``openai`` package
        # is only needed when this backend is actually instantiated.
        from openai import OpenAI

        self.client = OpenAI(api_key=api_key)
        self.api_key = api_key

    def _get_openai(self):
        """Return the ``openai`` module, imported lazily."""
        import openai

        return openai

    def gen(self, model, engine, messages, stream=False, **kwargs):
        """Run one chat completion and return the text of the first choice.

        Args:
            model: Model name sent as ``model`` to the API.
            engine: Accepted but not used by this backend.
            messages: Chat messages forwarded as ``messages``.
            stream: Forwarded to the API. Leave it ``False`` here: a streamed
                response is an iterator of chunks without ``.choices``; use
                ``gen_stream`` for streaming.
            **kwargs: Extra parameters forwarded to
                ``client.chat.completions.create``.

        Returns:
            ``response.choices[0].message.content``. Per the OpenAI SDK this
            can be ``None``, e.g. when the model answers with tool calls.
        """
        response = self.client.chat.completions.create(
            model=model,
            messages=messages,
            stream=stream,
            **kwargs,
        )
        return response.choices[0].message.content

    def gen_stream(self, model, engine, messages, stream=True, **kwargs):
        """Stream a chat completion, yielding text deltas as they arrive.

        Args:
            model: Model name sent as ``model`` to the API.
            engine: Accepted but not used by this backend.
            messages: Chat messages forwarded as ``messages``.
            stream: Forwarded to the API; should stay ``True`` so the
                response can be iterated chunk by chunk.
            **kwargs: Extra parameters forwarded to
                ``client.chat.completions.create``.

        Yields:
            Each non-``None`` ``delta.content`` string from the first choice
            of every chunk.
        """
        response = self.client.chat.completions.create(
            model=model,
            messages=messages,
            stream=stream,
            **kwargs,
        )

        for chunk in response:
            # A chunk can carry an empty ``choices`` list: the final usage
            # chunk when ``stream_options={"include_usage": True}`` is passed
            # through kwargs, and Azure's prompt-filter chunk. Indexing [0]
            # unconditionally would raise IndexError mid-stream.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            # Role-only and finish chunks have ``content`` set to None.
            if content is not None:
                yield content
|
|
|
|
|
class AzureOpenAILLM(OpenAILLM):
    """LLM backend for an Azure OpenAI deployment.

    Inherits ``gen``/``gen_stream`` from ``OpenAILLM``; only the client
    object differs (``openai.AzureOpenAI`` instead of ``openai.OpenAI``).
    """

    def __init__(
        self, openai_api_key, openai_api_base, openai_api_version, deployment_name
    ):
        """Create an ``openai.AzureOpenAI`` client for one deployment.

        Args:
            openai_api_key: Azure OpenAI API key.
            openai_api_base: Azure resource endpoint URL; falls back to
                ``settings.OPENAI_API_BASE`` when falsy.
            openai_api_version: Azure API version string; falls back to
                ``settings.OPENAI_API_VERSION`` when falsy.
            deployment_name: Azure deployment name; falls back to
                ``settings.AZURE_DEPLOYMENT_NAME`` when falsy.
        """
        super().__init__(openai_api_key)

        # Prefer the explicit constructor arguments (previously ignored) and
        # fall back to settings. No trailing commas: they previously turned
        # these attributes into 1-tuples.
        self.api_base = openai_api_base or settings.OPENAI_API_BASE
        self.api_version = openai_api_version or settings.OPENAI_API_VERSION
        self.deployment_name = deployment_name or settings.AZURE_DEPLOYMENT_NAME

        from openai import AzureOpenAI

        # openai>=1 names these parameters ``azure_endpoint`` and
        # ``azure_deployment``; ``api_base``/``deployment_name`` are rejected
        # with a TypeError.
        self.client = AzureOpenAI(
            api_key=openai_api_key,
            api_version=self.api_version,
            azure_endpoint=self.api_base,
            azure_deployment=self.deployment_name,
        )

    def _get_openai(self):
        """Return the ``openai`` module via the parent implementation."""
        return super()._get_openai()
|
|