"""Library of the project."""

# pylint: disable=wrong-import-position

# %% IMPORTS

# Import pysqlite3 first so the swap below can alias it as ``sqlite3``
# before chromadb imports sqlite3 (chromadb needs a newer SQLite than some
# system builds ship).
__import__("pysqlite3")

import functools
import os
import sys
from typing import Any, Callable, List

# https://docs.trychroma.com/troubleshooting#sqlite
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

import chromadb
import openai
import tiktoken
from chromadb.utils import embedding_functions

# %% CONFIGS

DATABASE_COLLECTION = "resume"
DATABASE_PATH = "database"
EMBEDDING_MODEL = "text-embedding-ada-002"
ENCODING_NAME = "cl100k_base"
ENCODING_OUTPUT_LIMIT = 8191
MODEL_NAME = "gpt-3.5-turbo-16k"
MODEL_INPUT_LIMIT = 16_385
MODEL_TEMPERATURE = 0.9
# Read eagerly: importing this module raises KeyError if the variable is unset.
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]

# %% TYPINGS

Collection = chromadb.Collection

# %% FUNCTIONS


def get_language_model(
    model: str = MODEL_NAME,
    api_key: str = OPENAI_API_KEY,
    temperature: float = MODEL_TEMPERATURE,
) -> Callable[..., Any]:
    """Get a callable that requests OpenAI chat completions.

    The returned object is ``openai.ChatCompletion.create`` with ``model``
    and ``temperature`` pre-bound. Callers supply the remaining arguments
    (e.g. ``messages=...``).

    Side effect: sets ``openai.api_key`` globally to ``api_key``.
    """
    openai.api_key = api_key  # configure the API key globally
    return functools.partial(
        openai.ChatCompletion.create, model=model, temperature=temperature
    )


def get_database_client(path: str = DATABASE_PATH) -> chromadb.API:
    """Get a persistent client to the Chroma DB stored at ``path``.

    The client is created with reset allowed and anonymized telemetry
    disabled.
    """
    settings = chromadb.Settings(allow_reset=True, anonymized_telemetry=False)
    return chromadb.PersistentClient(path=path, settings=settings)


def get_encoding_function(
    encoding_name: str = ENCODING_NAME,
) -> Callable[..., List[int]]:
    """Get the token-encoding function for OpenAI models.

    Returns the bound ``encode`` method of the tiktoken encoding named
    ``encoding_name``, not the ``Encoding`` object itself.
    """
    return tiktoken.get_encoding(encoding_name=encoding_name).encode


def get_embedding_function(
    model_name: str = EMBEDDING_MODEL, api_key: str = OPENAI_API_KEY
) -> embedding_functions.EmbeddingFunction:
    """Get the OpenAI embedding function for Chroma DB collections."""
    return embedding_functions.OpenAIEmbeddingFunction(
        model_name=model_name, api_key=api_key
    )