from typing import Optional, Generator
import os
from pathlib import Path
import tarfile
from dataclasses import dataclass
import torch
import lancedb
from lancedb.embeddings import get_registry
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub import InferenceClient, login
from transformers import AutoTokenizer
import gradio as gr
import argilla as rg
import uuid
@dataclass
class Settings:
    """Settings class to store useful variables for the App."""

    LANCEDB: str = "lancedb"
    LANCEDB_FILE_TAR: str = "lancedb.tar.gz"
    TOKEN: str = os.getenv("HF_API_TOKEN")
    LOCAL_DIR: Path = Path.home() / ".cache/argilla_sdk_docs_db"
    REPO_ID: str = "plaguss/argilla_sdk_docs_queries"
    TABLE_NAME: str = "docs"
    MODEL_NAME: str = "plaguss/bge-base-argilla-sdk-matryoshka"
    DEVICE: str = (
        "mps"
        if torch.backends.mps.is_available()
        else "cuda"
        if torch.cuda.is_available()
        else "cpu"
    )
    MODEL_ID: str = "meta-llama/Meta-Llama-3-70B-Instruct"
    ARGILLA_URL = r"https://plaguss-argilla-sdk-chatbot.hf.space"
    ARGILLA_API_KEY = os.getenv("ARGILLA_CHATBOT_API_KEY")
    ARGILLA_DATASET = "chatbot_interactions"


settings = Settings()
login(token=settings.TOKEN)

client_rg = rg.Argilla(
    api_url=settings.ARGILLA_URL,
    api_key=settings.ARGILLA_API_KEY,
)
argilla_dataset = client_rg.datasets(settings.ARGILLA_DATASET)
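# Illustrative note: assuming the environment variable names used above, the Space
# needs two secrets defined, and the Argilla dataset must already exist:
#
#     HF_API_TOKEN=hf_...              # read access to the Hub / Inference API
#     ARGILLA_CHATBOT_API_KEY=...      # API key of the Argilla Space owner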
def untar_file(source: Path) -> Path:
    """Untar and decompress files which have been compressed with `make_tarfile`.

    Args:
        source (Path): Path pointing to a .tar.gz file.

    Returns:
        filename (Path): Path of the decompressed file.
    """
    new_filename = source.parent / source.stem.replace(".tar", "")
    with tarfile.open(source, "r:gz") as f:
        f.extractall(source.parent)
    return new_filename
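# Example (illustrative): extracting the downloaded archive yields the lancedb folder.
#
#     db_dir = untar_file(Path.home() / ".cache/argilla_sdk_docs_db/lancedb.tar.gz")
#     # db_dir -> ~/.cache/argilla_sdk_docs_db/lancedb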
def download_database(
    repo_id: str,
    lancedb_file: str = "lancedb.tar.gz",
    local_dir: Path = Path.home() / ".cache/argilla_sdk_docs_db",
    token: str = os.getenv("HF_API_TOKEN"),
) -> Path:
    """Helper function to download the database. Downloads a compressed lancedb stored
    in a Hugging Face repository.

    Args:
        repo_id: Name of the repository where the database file is stored.
        lancedb_file: Name of the compressed file containing the lancedb database.
            Defaults to "lancedb.tar.gz".
        local_dir: Path where the file will be downloaded to. Defaults to
            Path.home() / ".cache/argilla_sdk_docs_db".
        token: Token for the Hugging Face Hub API. Defaults to os.getenv("HF_API_TOKEN").

    Returns:
        The path pointing to the database already uncompressed and ready to be used.
    """
    lancedb_download = Path(
        hf_hub_download(
            repo_id, lancedb_file, repo_type="dataset", token=token, local_dir=local_dir
        )
    )
    return untar_file(lancedb_download)
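# Example (illustrative): downloading and extracting the database with the defaults
# from Settings. The token must grant read access to the dataset repository.
#
#     db_path = download_database(settings.REPO_ID, local_dir=settings.LOCAL_DIR)
#     # db_path -> ~/.cache/argilla_sdk_docs_db/lancedb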
# Get the model to create the embeddings
model = (
    get_registry()
    .get("sentence-transformers")
    .create(name=settings.MODEL_NAME, device=settings.DEVICE)
)
class Database:
    """Interaction with the vector database to retrieve the chunks.

    On instantiation, will download the lancedb database if not already found in
    the expected location. Once ready, the only functionality available is
    to retrieve the doc chunks to be used as examples for the LLM.
    """

    def __init__(self, settings: Settings) -> None:
        """
        Args:
            settings: Instance of the settings.
        """
        self.settings = settings
        self._table: lancedb.table.LanceTable = self.get_table_from_db()
    def get_table_from_db(self) -> lancedb.table.LanceTable:
        """Obtains the table of the database containing the embedded docs.

        If the database is not found in the expected location, it will be downloaded
        first; then the connection is created, the table opened and returned.

        Returns:
            The table of the database containing the embedded chunks.
        """
        lancedb_db_path = self.settings.LOCAL_DIR / self.settings.LANCEDB
        if not lancedb_db_path.exists():
            lancedb_db_path = download_database(
                self.settings.REPO_ID,
                lancedb_file=self.settings.LANCEDB_FILE_TAR,
                local_dir=self.settings.LOCAL_DIR,
                token=self.settings.TOKEN,
            )

        db = lancedb.connect(str(lancedb_db_path))
        table = db.open_table(self.settings.TABLE_NAME)
        return table
    def retrieve_doc_chunks(
        self, query: str, limit: int = 12, hard_limit: int = 4
    ) -> str:
        """Search for similar queries in the database, and return the context to be
        passed to the LLM.

        Args:
            query: Query from the user.
            limit: Number of similar items to retrieve. Defaults to 12.
            hard_limit: Maximum number of unique chunks to keep. Because repeated
                questions were generated per chunk, the initial `limit` selection
                may contain duplicated document chunks; `hard_limit` caps the total
                number of unique chunks returned. Defaults to 4.

        Returns:
            The context to be used by the model to generate the response.
        """
        # Embed the query to use our custom model instead of the default one.
        embedded_query = model.generate_embeddings([query])
        field_to_retrieve = "text"
        retrieved = (
            self._table.search(embedded_query[0])
            .metric("cosine")
            .limit(limit)
            .select([field_to_retrieve])  # Just grab the chunk to use for context
            .to_list()
        )
        return self._prepare_context(retrieved, hard_limit)
    def _prepare_context(
        self, retrieved: list[dict[str, str]], hard_limit: int
    ) -> str:
        """Prepares the examples to be used in the LLM prompt.

        Args:
            retrieved: The list of retrieved chunks.
            hard_limit: Max number of doc pieces to return.

        Returns:
            Context to be used by the LLM.
        """
        # We have repeated questions (up to 4) for a given chunk, so we may get
        # repeated chunks. Request more than necessary and filter them afterwards.
        responses = []
        unique_responses = set()

        for item in retrieved:
            chunk = item["text"]
            if chunk not in unique_responses:
                unique_responses.add(chunk)
                responses.append(chunk)

        context = ""
        for i, item in enumerate(responses[:hard_limit]):
            if i > 0:
                context += "\n\n"
            context += f"---\n{item}"
        return context


database = Database(settings=settings)
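# Example (illustrative): fetching context for a user query. The exact chunks depend
# on the contents of the "docs" table; the result is a "---"-separated string.
#
#     context = database.retrieve_doc_chunks("How can I connect to an argilla server?")
#     # context -> "---\n<chunk 1>\n\n---\n<chunk 2>\n\n..."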
def get_client_and_tokenizer(
    model_id: str = settings.MODEL_ID, tokenizer_id: Optional[str] = None
) -> tuple[InferenceClient, AutoTokenizer]:
    """Obtains the inference client and the tokenizer corresponding to the model.

    Args:
        model_id: The name of the model. Currently it must be one available in the
            free tier. Defaults to "meta-llama/Meta-Llama-3-70B-Instruct".
        tokenizer_id: The name of the corresponding tokenizer. Defaults to None,
            in which case it will use the same as the `model_id`.

    Returns:
        The client and tokenizer chosen.
    """
    if tokenizer_id is None:
        tokenizer_id = model_id

    client = InferenceClient()
    base_url = client._resolve_url(model=model_id, task="text-generation")
    # Note: We could move to the AsyncClient
    client = InferenceClient(model=base_url, token=os.getenv("HF_API_TOKEN"))

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    return client, tokenizer
client_kwargs = {
    "stream": True,
    "max_new_tokens": 512,
    "do_sample": False,
    "typical_p": None,
    "repetition_penalty": None,
    "temperature": 0.3,
    "top_p": None,
    "top_k": None,
    "stop_sequences": ["<|eot_id|>", "<|end_of_text|>"]
    if settings.MODEL_ID.startswith("meta-llama/Meta-Llama-3")
    else None,
    "seed": None,
}

client, tokenizer = get_client_and_tokenizer()
SYSTEM_PROMPT = """\
You are a support expert in Argilla SDK, whose goal is to help users with their questions.
As a trustworthy expert, you must provide truthful answers to questions using only the provided documentation snippets, not prior knowledge.
Here are guidelines you must follow when responding to user questions:

**Purpose and Functionality**
- Answer questions related to the Argilla SDK.
- Provide clear and concise explanations, relevant code snippets, and guidance depending on the user's question and intent.
- Ensure users succeed in effectively understanding and using Argilla's features.
- Provide accurate responses to the user's questions.

**Specificity**
- Be specific and provide details only when required.
- Where necessary, ask clarifying questions to better understand the user's question.
- Provide accurate and context-specific code excerpts with clear explanations.
- Ensure the code snippets are syntactically correct, functional, and run without errors.
- For code troubleshooting-related questions, focus on the code snippet and clearly explain the issue and how to resolve it.
- Avoid boilerplate code such as imports, installs, etc.

**Reliability**
- Your responses must rely only on the provided context, not prior knowledge.
- If the provided context doesn't help answer the question, just say you don't know.
- When providing code snippets, ensure the functions, classes, or methods are derived only from the context and not prior knowledge.
- Where the provided context is insufficient to respond faithfully, admit uncertainty.
- Remind the user of your specialization in Argilla SDK support when a question is outside your domain of expertise.
- Redirect the user to the appropriate support channels - Argilla [community](https://join.slack.com/t/rubrixworkspace/shared_invite/zt-whigkyjn-a3IUJLD7gDbTZ0rKlvcJ5g) when the question is outside your capabilities or you do not have enough context to answer the question.

**Response Style**
- Use clear, concise, professional language suitable for technical support.
- Do not refer to the context in the response (e.g., "As mentioned in the context..."); instead, provide the information directly in the response.

**Example**:

The correct answer to the user's query

Steps to solve the problem:
- **Step 1**: ...
- **Step 2**: ...
...

Here's a code snippet

```python
# Code example
...
```

**Explanation**:
- Point 1
- Point 2
...
"""
ARGILLA_BOT_TEMPLATE = """\
Please provide an answer to the following question related to Argilla's new SDK.

You can make use of the chunks of documents in the context to help you generate the response.

## Query:
{message}

## Context:
{context}
"""
def prepare_input(message: str, history: list[tuple[str, str]]) -> str:
    """Prepares the input to be passed to the LLM.

    Args:
        message: Message from the user, the query.
        history: Previous list of messages from the user and the answers, as a list
            of tuples with user/assistant messages.

    Returns:
        The string with the template formatted to be sent to the LLM.
    """
    # Retrieve the context from the database
    context = database.retrieve_doc_chunks(message)

    # Prepare the conversation for the model.
    conversation = []
    for human, bot in history:
        conversation.append({"role": "user", "content": human})
        conversation.append({"role": "assistant", "content": bot})

    conversation.insert(0, {"role": "system", "content": SYSTEM_PROMPT})
    conversation.append(
        {
            "role": "user",
            "content": ARGILLA_BOT_TEMPLATE.format(message=message, context=context),
        }
    )

    return tokenizer.apply_chat_template(
        [conversation],
        tokenize=False,
        add_generation_prompt=True,
    )[0]
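# Example (illustrative): the prompt built for a fresh conversation. With the Llama 3
# tokenizer the resulting string roughly looks like the following (system turn first,
# then a user turn containing the query and the retrieved context):
#
#     prompt = prepare_input("How can I get the current user?", history=[])
#     # "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<SYSTEM_PROMPT>..."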
def create_chat_html(history: list[tuple[str, str]]) -> str:
    """Helper function to render a conversation as HTML to be stored in Argilla.

    Args:
        history: History of messages with the chatbot.

    Returns:
        HTML formatted conversation.
    """
    chat_html = ""
    alignments = ["right", "left"]
    colors = ["#c2e3f7", "#f5f5f5"]

    for turn in history:
        # Create the HTML message div with inline styles
        message_html = ""

        # Include the last message even if it hasn't been answered yet
        (user, assistant) = turn
        if assistant is None:
            turn = (user,)

        for i, content in enumerate(turn):
            message_html += f'<div style="display: flex; justify-content: {alignments[i]}; margin: 10px;">'
            message_html += f'<div style="background-color: {colors[i]}; padding: 10px; border-radius: 10px; max-width: 70%; word-wrap: break-word;">{content}</div>'
            message_html += "</div>"

        # Add the message to the chat HTML
        chat_html += message_html

    return chat_html
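# Example (illustrative): one answered turn renders as two bubbles, user on the right
# and assistant on the left. This HTML is what gets logged as the "instruction" field.
#
#     html = create_chat_html([("How can I access a dataset?", "You can use ...")])
#     # '<div style="display: flex; justify-content: right; ...">How can I access a dataset?</div>...'
#     # '<div style="display: flex; justify-content: left; ...">You can use ...</div>...'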
conv_id = str(uuid.uuid4())


def chatty(message: str, history: list[tuple[str, str]]) -> Generator[str, None, None]:
    """Main function of the app, contains the interaction with the LLM.

    Args:
        message: Message from the user, the query.
        history: Previous list of messages from the user and the answers, as a list
            of tuples with user/assistant messages.

    Yields:
        The streaming response, printed in the interface as it's being received.
    """
    prompt = prepare_input(message, history)

    partial_response = ""
    for token_stream in client.text_generation(prompt=prompt, **client_kwargs):
        partial_response += token_stream
        yield partial_response

    global conv_id
    new_conversation = len(history) == 0
    if new_conversation:
        conv_id = str(uuid.uuid4())
    else:
        history.append((message, None))

    # Register the interaction in the Argilla dataset
    argilla_dataset.records.log(
        [
            {
                "instruction": create_chat_html(history) if history else message,
                "response": partial_response,
                "conv_id": conv_id,
                "turn": len(history),
            },
        ]
    )
if __name__ == "__main__":
    gr.ChatInterface(
        chatty,
        chatbot=gr.Chatbot(height=700),
        textbox=gr.Textbox(
            placeholder="Ask me about the new argilla SDK", container=False, scale=7
        ),
        title="Argilla SDK Chatbot",
        description="Ask a question about Argilla SDK",
        theme="soft",
        examples=[
            "How can I connect to an argilla server?",
            "How can I access a dataset?",
            "How can I get the current user?",
        ],
        cache_examples=True,
        retry_btn=None,
    ).launch()