# chatbot-llamaindex / backend.py
# Author: gufett0 - commit 6ed7896: "added introductory prompt"
import torch
import os
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer, AutoTokenizer
from interface import GemmaLLMInterface
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.instructor import InstructorEmbedding
import gradio as gr
from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader, PromptTemplate, load_index_from_storage, StorageContext
from llama_index.core.node_parser import SentenceSplitter
import spaces
from huggingface_hub import login
from llama_index.core.memory import ChatMemoryBuffer
from typing import Iterator, List, Any
from llama_index.core.chat_engine import CondensePlusContextChatEngine
from llama_index.core.llms import ChatMessage, MessageRole
# --- Hugging Face authentication ---------------------------------------------
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
# Calling login() without a token falls back to an interactive prompt, which
# cannot be answered in a headless deployment. Only log in when the variable
# is actually set; otherwise `token=True` below relies on a cached token.
if huggingface_token:
    login(huggingface_token)

# --- Model -------------------------------------------------------------------
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "google/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    # bfloat16 on GPU, full precision on CPU.
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    token=True,
)
# Load the tokenizer from the same checkpoint as the weights so the two can
# never drift apart (the tokenizer was previously pulled from a different
# repository than `model_id`).
model.tokenizer = AutoTokenizer.from_pretrained(model_id)
model.eval()

# --- Models used by LlamaIndex -----------------------------------------------
Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface()

# Topic keyword -> document path. A keyword is matched as a substring of the
# lower-cased user query in handle_query().
documents_paths = {
    'blockchain': 'data/blockchainprova.txt',
    'metaverse': 'data/metaverso',
    'payment': 'data/payment',
}

# Process-wide state. Only "index" is read or written in this part of the
# file: it records whether a topic index has already been built and persisted.
# NOTE(review): the remaining keys are not referenced here - confirm whether
# other code still uses them.
session_state = {
    "index": False,
    "documents_loaded": False,
    "document_db": None,
    "original_message": None,
    "clarification": False,
}

# Directory where build_index() persists the index.
PERSIST_DIR = "./db"
os.makedirs(PERSIST_DIR, exist_ok=True)

# --- Node parser -------------------------------------------------------------
# Splits documents into 256-unit chunks with a 64-unit overlap.
parser = SentenceSplitter.from_defaults(
    chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
)
def build_index(path: str) -> VectorStoreIndex:
    """Build a vector index from ``path`` and persist it under ``PERSIST_DIR``.

    Args:
        path: Either a single file or a directory. For a directory, the
            reader is pointed at the directory itself.

    Returns:
        The newly built ``VectorStoreIndex``.
    """
    # The reader takes files and directories through different arguments;
    # passing a directory via ``input_files`` is rejected, so dispatch on the
    # kind of path we were given.
    if os.path.isdir(path):
        reader = SimpleDirectoryReader(input_dir=path)
    else:
        reader = SimpleDirectoryReader(input_files=[path])
    documents = reader.load_data()

    # Split the documents into nodes with the module-level sentence splitter.
    nodes = parser.get_nodes_from_documents(documents)

    # Build the index and write its storage to disk so later calls can reload
    # it instead of re-embedding the documents.
    index = VectorStoreIndex(nodes)
    index.storage_context.persist(persist_dir=PERSIST_DIR)
    return index
def _resolve_index(query_str: str):
    """Return the index to answer ``query_str`` with, building or loading it."""
    if session_state["index"]:
        # A topic index was already built and persisted: reload it.
        storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
        return load_index_from_storage(storage_context)

    # First topic keyword found (as a substring) in the lower-cased query.
    lowered = query_str.lower()
    matched_path = next(
        (path for key, path in documents_paths.items() if key in lowered),
        None,
    )
    if matched_path:
        index = build_index(matched_path)
        session_state["index"] = True
        return index

    # No topic recognised: fall back to the clarification document. The flag
    # stays False so the next query gets another chance to pick a topic.
    return build_index("data/chiarimento.txt")


@spaces.GPU(duration=20)
def handle_query(query_str: str,
                 chat_history: list[tuple[str, str]]) -> Iterator[str]:
    """Stream the assistant's answer to ``query_str``.

    Args:
        query_str: The user's current message.
        chat_history: Previous ``(user, assistant)`` message pairs.

    Yields:
        The response accumulated so far, once per generated token. If
        resolving the index or generating the answer fails, a single
        ``"Error processing query: ..."`` message is yielded instead.
    """
    # Convert the (user, assistant) pairs into LlamaIndex chat messages.
    conversation: List[ChatMessage] = []
    for user, assistant in chat_history:
        conversation.extend([
            ChatMessage(role=MessageRole.USER, content=user),
            ChatMessage(role=MessageRole.ASSISTANT, content=assistant),
        ])

    try:
        # Resolved inside the try so a missing document or an unreadable
        # persisted store is reported like any other failure instead of
        # raising out of the generator.
        index = _resolve_index(query_str)

        memory = ChatMemoryBuffer.from_defaults(token_limit=None)
        chat_engine = index.as_chat_engine(
            chat_mode="condense_plus_context",
            memory=memory,
            similarity_top_k=4,
            # NOTE(review): intended for summarization; confirm that the
            # condense_plus_context engine actually honours this option.
            response_mode="tree_summarize",
            context_prompt=(
                "Sei un assistente Q&A italiano di nome Odi, che risponde solo alle domande o richieste pertinenti in modo preciso."
                " Quando un utente ti chiede informazioni su di te o sul tuo creatore puoi dire che sei un assistente ricercatore creato dagli Osservatori Digitali e fornire gli argomenti di cui sei esperto."
                " Ecco i documenti rilevanti per il contesto:\n"
                "{context_str}"
                "\nIstruzione: Usa la cronologia delle chat precedenti, o il contesto sopra, per interagire e aiutare l'utente a rispondere alla sua domanda."
            ),
            verbose=False,
        )

        outputs = []
        response = chat_engine.stream_chat(query_str, conversation)
        for token in response.response_gen:
            outputs.append(token)
            # Yield the cumulative text, not just the newest token.
            yield "".join(outputs)
    except Exception as e:
        yield f"Error processing query: {str(e)}"