import os
import torch
import gradio as gr
import spaces
from typing import Any, Iterator, List

from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer, AutoTokenizer
from huggingface_hub import hf_hub_download, login
from llama_cpp import Llama

from interface import GemmaLLMInterface
from llama_index.core import (
    ChatPromptTemplate,
    PromptTemplate,
    Settings,
    SimpleDirectoryReader,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.chat_engine import CondensePlusContextChatEngine
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.embeddings.instructor import InstructorEmbedding

huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "google/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    token=True,
)
# Attach the matching tokenizer so GemmaLLMInterface can reuse it.
model.tokenizer = AutoTokenizer.from_pretrained(model_id)
model.eval()

# Models used by LlamaIndex:
Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface()

# Knowledge bases, keyed by the topic keyword expected in the user's query.
documents_paths = {
    "blockchain": "data/blockchainprova.txt",
    "metaverse": "data/metaverso",
    "payment": "data/payment",
}

session_state = {
    "documents_loaded": False,
    "document_db": None,
    "original_message": None,
    "clarification": False,
    "index": False,  # set to True once a topic-specific index has been built
}

############################---------------------------------
# Parser used to split documents into chunks before indexing.
parser = SentenceSplitter.from_defaults(
    chunk_size=256,
    chunk_overlap=64,
    paragraph_separator="\n\n",
)


def build_index(path: str) -> VectorStoreIndex:
    # Load documents from a file
    documents = SimpleDirectoryReader(input_files=[path]).load_data()
    # Parse the documents into nodes
    nodes = parser.get_nodes_from_documents(documents)
    # Build the vector store index from the nodes
    return VectorStoreIndex(nodes)


@spaces.GPU(duration=20)
def handle_query(query_str: str, chat_history: list[tuple[str, str]], session: dict[str, Any]) -> Iterator[str]:
    global index

    # Rebuild the conversation history in the format LlamaIndex expects.
    conversation: List[ChatMessage] = []
    for user, assistant in chat_history:
        conversation.extend(
            [
                ChatMessage(role=MessageRole.USER, content=user),
                ChatMessage(role=MessageRole.ASSISTANT, content=assistant),
            ]
        )

    if not session.get("index"):
        # Pick the knowledge base whose topic keyword appears in the query.
        matched_path = None
        words = query_str.lower()
        for key, path in documents_paths.items():
            if key in words:
                matched_path = path
                break
        if matched_path:
            index = build_index(matched_path)
            session["index"] = True
        else:
            # Ask for clarification: no topic matched, so answer from the
            # clarification document instead.
            index = build_index("data/chiarimento.txt")
    # Otherwise the index is already built, no need to rebuild it.
    try:
        memory = ChatMemoryBuffer.from_defaults(token_limit=None)
        # Italian context prompt: Odi, a Q&A research assistant created by the
        # Osservatori Digitali, answering only relevant questions.
        chat_engine = index.as_chat_engine(
            chat_mode="condense_plus_context",
            memory=memory,
            similarity_top_k=4,
            response_mode="tree_summarize",  # good for summarization purposes
            context_prompt=(
                "Sei un assistente Q&A italiano di nome Odi, che risponde solo alle domande o richieste pertinenti in modo preciso."
                " Quando un utente ti chiede informazioni su di te o sul tuo creatore puoi dire che sei un assistente ricercatore creato dagli Osservatori Digitali e fornire gli argomenti di cui sei esperto."
                " Ecco i documenti rilevanti per il contesto:\n"
                "{context_str}"
                "\nIstruzione: Usa la cronologia delle chat precedenti, o il contesto sopra, per interagire e aiutare l'utente a rispondere alla sua domanda."
            ),
            verbose=False,
        )

        # Stream the answer token by token, yielding the accumulated text so
        # the Gradio UI can update incrementally.
        outputs = []
        response = chat_engine.stream_chat(query_str, conversation)
        for token in response.response_gen:
            outputs.append(token)
            yield "".join(outputs)
    except Exception as e:
        yield f"Error processing query: {str(e)}"
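# --------------------------------------------------------------------------
# Hypothetical UI wiring (sketch only): this section does not show how
# handle_query is exposed through Gradio, so the block below is just one
# plausible way to hook it up. The `demo` name, the gr.ChatInterface /
# gr.State usage, and the title are assumptions, not part of the original app.
if __name__ == "__main__":
    demo = gr.ChatInterface(
        fn=handle_query,                              # generator: streams partial answers to the UI
        additional_inputs=[gr.State(session_state)],  # per-session copy passed as the third argument
        title="Odi",                                  # assistant name taken from the context prompt
    )
    demo.launch()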