# chatbot-llamaindex / backend.py
# Last commit e76e7f8 by gufett0: "hf llm" (file size 9.1 kB)
import torch
import os
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer, AutoTokenizer
#from interface import GemmaLLMInterface
from llama_index.embeddings.instructor import InstructorEmbedding
import gradio as gr
from llama_index.core import Settings, ServiceContext, VectorStoreIndex, SimpleDirectoryReader, ChatPromptTemplate, PromptTemplate, load_index_from_storage, StorageContext
from llama_index.core.node_parser import SentenceSplitter
import spaces
from huggingface_hub import login
from llama_index.core.memory import ChatMemoryBuffer
from typing import Iterator, List, Any
from llama_index.core.chat_engine import CondensePlusContextChatEngine
from llama_index.core.llms import ChatMessage, MessageRole , CompletionResponse
from IPython.display import Markdown, display
from langchain_huggingface import HuggingFaceEmbeddings
#from llama_index import LangchainEmbedding, ServiceContext
#from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceInferenceAPI, HuggingFaceLLM
from dotenv import load_dotenv
import logging
import sys
# Send log records to stdout exactly once: logging.basicConfig already attaches
# a stdout StreamHandler to the root logger, so attaching another one would
# print every record twice.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Read a local .env file *before* looking up the token, so a HUGGINGFACE_TOKEN
# defined there is visible to os.getenv (real environment variables still win).
load_dotenv()
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
if huggingface_token:
    login(huggingface_token)
else:
    # login(None) would fall back to an interactive prompt, which cannot work
    # in a headless deployment; warn instead and let Hub calls use any cached
    # credentials.
    logging.warning("HUGGINGFACE_TOKEN is not set; skipping Hugging Face login.")
"""huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_id = "google/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
torch_dtype= torch.bfloat16 if torch.cuda.is_available() else torch.float32,
token=True)
tokenizer= AutoTokenizer.from_pretrained("google/gemma-2b-it")
model.tokenizer = tokenizer
model.eval()"""
system_prompt="""
You are a Q&A assistant. Your goal is to answer questions as
accurately as possible based on the instructions and context provided.
"""
load_dotenv()
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
llm = HuggingFaceLLM(
context_window=4096,
max_new_tokens=256,
generate_kwargs={"temperature": 0.1, "do_sample": True},
system_prompt=system_prompt,
tokenizer_name="meta-llama/Llama-2-7b-chat-hf",
model_name="meta-llama/Llama-2-7b-chat-hf",
device_map="auto",
# loading model in 8bit for reducing memory
model_kwargs={"torch_dtype": torch.float16 }
)
embed_model= HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
Settings.llm = llm
Settings.embed_model = embed_model
#Settings.node_parser = SentenceSplitter(chunk_size=512, chunk_overlap=20, paragraph_separator="\n\n")
Settings.num_output = 512
Settings.context_window = 3900
# Read every file under ./data once at import time and cut it into chunks of
# at most 512 tokens (20-token overlap, paragraphs split on blank lines).
# `nodes` is the corpus that handle_query builds its vector index from.
documents = SimpleDirectoryReader("./data").load_data()
nodes = SentenceSplitter(
    chunk_size=512,
    chunk_overlap=20,
    paragraph_separator="\n\n",
).get_nodes_from_documents(documents)

# Alternative models that were tried (inactive):
#Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
#Settings.embed_model = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
#Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
#Settings.llm = GemmaLLMInterface()
# Topic keyword -> source text file for that "Osservatorio".
documents_paths = {
    'blockchain': 'data/blockchainprova.txt',
    'metaverse': 'data/metaverseprova.txt',
    'payment': 'data/paymentprova.txt'
}

# Module-level mutable flags. No `global` statement is needed here: names bound
# at module scope are already global.
# NOTE(review): as a module-level dict this is shared by every request handled
# by the process, not per user session.
session_state = {"index": False,
                 "documents_loaded": False,
                 "document_db": None,
                 "original_message": None,
                 "clarification": False}

# Directory presumably intended for a persisted index; created if missing.
PERSIST_DIR = "./db"
os.makedirs(PERSIST_DIR, exist_ok=True)

# Italian instruction: "In Italian, ask very briefly whether the question refers
# to the 'Osservatori Blockchain', 'Osservatori Payment' or 'Osservatori
# Metaverse'."
ISTR = "In italiano, chiedi molto brevemente se la domanda si riferisce agli 'Osservatori Blockchain', 'Osservatori Payment' oppure 'Osservatori Metaverse'."
# ---------------------------------------------------------------------------
# Inactive: a per-file index builder (SentenceSplitter with chunk_size=256,
# chunk_overlap=64) wrapped in a bare string literal. The string is evaluated
# and discarded at import time, so `parser` and `build_index` are NOT defined
# at runtime.
"""parser = SentenceSplitter.from_defaults(
chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
)
def build_index(path: str):
# Load documents from a file
documents = SimpleDirectoryReader(input_files=[path]).load_data()
# Parse the documents into nodes
nodes = parser.get_nodes_from_documents(documents)
# Build the vector store index from the nodes
index = VectorStoreIndex(nodes)
#storage_context = StorageContext.from_defaults()
#index.storage_context.persist(persist_dir=PERSIST_DIR)
return index"""
@spaces.GPU(duration=15)
def handle_query(query_str: str,
                 chat_history: list[tuple[str, str]]) -> Iterator[str]:
    """Answer ``query_str`` with retrieval-augmented chat, streaming the reply.

    A ``VectorStoreIndex`` is built from the module-level ``nodes`` and wrapped
    in a ``"context"`` chat engine (top-3 retrieval, 1500-token memory buffer),
    which is queried in streaming mode with the prior conversation.

    Args:
        query_str: The user's latest message.
        chat_history: Prior turns as ``(user, assistant)`` pairs, oldest first.
            ``None`` is treated as an empty history.

    Yields:
        The cumulative answer text after each streamed token, with a leading
        ``"assistant:"`` role prefix removed if the model emits one. If anything
        fails, a single ``"Error processing query: ..."`` message instead.
    """
    role_prefix = "assistant:"

    conversation: List[ChatMessage] = []
    for user, assistant in chat_history or []:
        conversation.extend([
            ChatMessage(role=MessageRole.USER, content=user),
            ChatMessage(role=MessageRole.ASSISTANT, content=assistant),
        ])

    try:
        # NOTE(review): the index, and therefore every chunk embedding, is
        # rebuilt on each request; consider building it once if this is slow.
        index = VectorStoreIndex(nodes, show_progress=True)
        memory = ChatMemoryBuffer.from_defaults(token_limit=1500)

        logging.info("chat engine..")
        gr.Info("chat engine..")
        chat_engine = index.as_chat_engine(
            chat_mode="context",
            similarity_top_k=3,
            memory=memory,
            context_prompt=(
                "Sei un assistente Q&A italiano di nome Odi, che risponde solo alle domande o richieste pertinenti in modo preciso."
                " Usa la cronologia della chat, o il contesto fornito, per interagire e aiutare l'utente a rispondere alla sua domanda."
            ),
        )

        response = chat_engine.stream_chat(query_str, chat_history=conversation)

        outputs: List[str] = []
        for token in response.response_gen:
            # Keep every token. Previously only tokens starting with
            # "assistant:" were kept, which dropped almost the whole answer.
            outputs.append(token)
            logging.debug("Generated token: %r", token)
            text = "".join(outputs)
            # Strip a role prefix the model may echo at the start of its answer.
            # Check the accumulated text rather than single tokens, because
            # the prefix can be split across several streamed tokens.
            stripped = text.lstrip()
            if stripped.startswith(role_prefix):
                text = stripped[len(role_prefix):]
            yield text
    except Exception as e:
        # Top-level boundary of a UI handler: log the traceback for the
        # operator and show a readable message in the chat.
        logging.exception("Error processing query")
        yield f"Error processing query: {str(e)}"