from __future__ import annotations

import enum
import os
from typing import ClassVar, List, Optional

import openai
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam

from llm_handler.llm_interface import DefaultEnumMeta, LLMInterface


class ChatModelVersion(enum.Enum, metaclass=DefaultEnumMeta):
    GPT_3_5 = 'gpt-3.5-turbo-1106'
    GPT_4 = 'gpt-4'
    GPT_4_TURBO = 'gpt-4-1106-preview'
    GPT_4_O = 'gpt-4o'


class EmbeddingModelVersion(enum.Enum, metaclass=DefaultEnumMeta):
    SMALL_3 = 'text-embedding-3-small'
    ADA_002 = 'text-embedding-ada-002'
    LARGE = 'text-embedding-3-large'


class OpenAIHandler(LLMInterface):
    _ENV_KEY_NAME: ClassVar[str] = 'OPENAI_API_KEY'

    _client: openai.Client

    def __init__(self, openai_api_key: Optional[str] = None):
        # Prefer an explicitly passed key; fall back to the environment variable.
        _openai_api_key = openai_api_key or os.environ.get(self._ENV_KEY_NAME)
        if not _openai_api_key:
            raise ValueError(f'{self._ENV_KEY_NAME} not set')
        # Pass the key to the client directly: in the v1 SDK, openai.Client()
        # only reads the OPENAI_API_KEY env var, so assigning the legacy
        # module-level openai.api_key would silently drop an explicit key.
        self._client = openai.Client(api_key=_openai_api_key)

    def get_chat_completion(  # type: ignore
            self,
            messages: List[ChatCompletionMessageParam],
            model: ChatModelVersion = ChatModelVersion.GPT_4_O,
            temperature: float = 0.2,
            **kwargs) -> str:
        response = self._client.chat.completions.create(model=model.value,
                                                        messages=messages,
                                                        temperature=temperature,
                                                        **kwargs)
        responses: List[str] = []
        for choice in response.choices:
            # Treat truncated or empty choices (e.g. finish_reason == 'length')
            # as errors rather than returning partial output.
            if choice.finish_reason != 'stop' or not choice.message.content:
                raise ValueError(f'Choice did not complete correctly: {choice}')
            responses.append(choice.message.content)
        if len(responses) != 1:
            raise ValueError(f'Expected one response, got {len(responses)}: {responses}')
        return responses[0]

    def get_text_embedding(  # type: ignore
            self, input: str, model: EmbeddingModelVersion) -> List[float]:
        response = self._client.embeddings.create(model=model.value,
                                                  encoding_format='float',
                                                  input=input)
        if not response.data:
            raise ValueError(f'No embedding in response: {response}')
        if len(response.data) != 1:
            raise ValueError(f'More than one embedding in response: {response}')
        return response.data[0].embedding
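

if __name__ == '__main__':
    # Illustrative usage sketch, not part of the handler itself. It assumes
    # OPENAI_API_KEY is set in the environment and that the account has access
    # to the models referenced below; swap in other enum members as needed.
    handler = OpenAIHandler()

    # Single-turn chat completion. The message dicts follow the OpenAI chat
    # format ('role' + 'content') expected by ChatCompletionMessageParam.
    reply = handler.get_chat_completion(
        messages=[{'role': 'user', 'content': 'Say hello in one short sentence.'}],
        model=ChatModelVersion.GPT_4_O)
    print(reply)

    # Embed a short string; the result is a list of floats whose length
    # depends on the chosen embedding model.
    vector = handler.get_text_embedding(input='hello world',
                                        model=EmbeddingModelVersion.SMALL_3)
    print(len(vector))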