Spaces:
Running
Running
"""Library of the project.""" | |
# pylint: disable=wrong-import-position

# %% IMPORTS

__import__("pysqlite3")
import functools
import os
import sys
import typing

# https://docs.trychroma.com/troubleshooting#sqlite
# NOTE: the swap below must run before chromadb is imported.
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

import chromadb
import openai
import tiktoken
from chromadb.utils import embedding_functions

# %% CONFIGS

# Chroma DB: on-disk location and collection name.
DATABASE_PATH = "database"
DATABASE_COLLECTION = "resume"

# Embedding model and the tiktoken encoding used with it.
# NOTE(review): the limit is presumably a token count -- confirm against usage.
EMBEDDING_MODEL = "text-embedding-ada-002"
ENCODING_NAME = "cl100k_base"
ENCODING_OUTPUT_LIMIT = 8191

# Chat model defaults used by get_language_model.
# NOTE(review): the input limit is presumably a token count -- confirm.
MODEL_NAME = "gpt-3.5-turbo-16k"
MODEL_INPUT_LIMIT = 16_385
MODEL_TEMPERATURE = 0.9

# Read at import time; a missing variable raises KeyError immediately.
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]

# %% TYPINGS

# Short alias so annotations can refer to a Chroma collection.
Collection = chromadb.Collection

# %% FUNCTIONS

def get_language_model(
    model: str = MODEL_NAME,
    api_key: str = OPENAI_API_KEY,
    temperature: float = MODEL_TEMPERATURE,
) -> functools.partial:
    """Get a callable bound to the OpenAI ChatCompletion endpoint.

    Args:
        model: model name pre-filled on every call.
        api_key: OpenAI API key; assigned to ``openai.api_key``.
        temperature: temperature pre-filled on every call.

    Returns:
        A ``functools.partial`` of ``openai.ChatCompletion.create`` with
        ``model`` and ``temperature`` already bound. It is a callable, not a
        ``ChatCompletion`` instance; remaining arguments are given at call
        time.
    """
    # Side effect: the key is set on the openai module, so it applies to
    # every later openai call in the process, not only to this partial.
    openai.api_key = api_key
    return functools.partial(
        openai.ChatCompletion.create, model=model, temperature=temperature
    )

def get_database_client(path: str = DATABASE_PATH) -> chromadb.API:
    """Get a persistent client to the Chroma DB.

    Args:
        path: directory of the on-disk database. Defaults to
            ``DATABASE_PATH``, like the other factories in this module
            default to their config constants.

    Returns:
        The client created by ``chromadb.PersistentClient`` for ``path``.
    """
    # Resets are allowed and anonymized telemetry is turned off.
    settings = chromadb.Settings(allow_reset=True, anonymized_telemetry=False)
    return chromadb.PersistentClient(path=path, settings=settings)

def get_encoding_function(
    encoding_name: str = ENCODING_NAME,
) -> typing.Callable[..., typing.List[int]]:
    """Get the encoding function for OpenAI models.

    Args:
        encoding_name: name of the tiktoken encoding to load.

    Returns:
        The bound ``encode`` method of the loaded ``tiktoken.Encoding``,
        i.e. a callable that turns text into token ids, not the
        ``Encoding`` object itself.
    """
    encoding = tiktoken.get_encoding(encoding_name=encoding_name)
    return encoding.encode

def get_embedding_function(
    model_name: str = EMBEDDING_MODEL,
    api_key: str = OPENAI_API_KEY,
) -> embedding_functions.EmbeddingFunction:
    """Build the OpenAI embedding function used with Chroma DB collections.

    Args:
        model_name: name of the OpenAI embedding model.
        api_key: OpenAI API key passed to the embedding function.

    Returns:
        An ``OpenAIEmbeddingFunction`` configured with the given model and key.
    """
    embedder = embedding_functions.OpenAIEmbeddingFunction(
        model_name=model_name,
        api_key=api_key,
    )
    return embedder