import os
from typing import Iterator

import torch
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download, login
from llama_cpp import Llama
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaTokenizerFast, TextIteratorStreamer

from llama_index.core import (
    ChatPromptTemplate,
    PromptTemplate,
    Settings,
    SimpleDirectoryReader,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.instructor import InstructorEmbedding

from interface import GemmaLLMInterface

# Authenticate with the Hugging Face Hub (the Gemma weights are gated).
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the Gemma 2 2B instruction-tuned model: bfloat16 on GPU, float32 on CPU.
model_id = "google/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    token=True,
)
# Attach the matching tokenizer to the model so GemmaLLMInterface can reach it.
model.tokenizer = AutoTokenizer.from_pretrained(model_id)
model.eval()

# Models used by LlamaIndex: an Instructor embedding model and the local Gemma wrapper.
Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface()

# ---------------------------------------------------------------------------

# Sentence-level parser that splits documents into 256-token chunks with 64-token overlap.
parser = SentenceSplitter.from_defaults(
    chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
)


def build_index():
    # Load documents from a file.
    documents = SimpleDirectoryReader(input_files=["data/blockchainprova.txt"]).load_data()
    # Parse the documents into nodes.
    nodes = parser.get_nodes_from_documents(documents)
    # Build the vector store index from the nodes.
    index = VectorStoreIndex(nodes)
    return index


@spaces.GPU(duration=20)
def handle_query(query_str, chathistory) -> Iterator[str]:
    # The index is rebuilt on every call so it lives inside the GPU-allocated request.
    index = build_index()
    try:
        memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
        chat_engine = index.as_chat_engine(
            chat_mode="context",
            memory=memory,
            # Italian system prompt: the assistant is named Odi, answers only relevant
            # questions precisely, and has a fixed reply about itself and its creator
            # ("Osservatori Digitali").
            system_prompt=(
                "Sei un assistente Q&A italiano di nome Odi, che risponde solo alle domande o richieste pertinenti in modo preciso. "
                "Hai una risposta predefinita per quando un utente ti chiede informazioni su di te o sul tuo creatore, ovvero: "
                "'Sono un assistente ricercatore creato dagli Osservatori Digitali'."
            ),
        )
        outputs = []
        # Stream the answer token by token, skipping any echoed role prefixes.
        response = chat_engine.stream_chat(query_str)
        for token in response.response_gen:
            if not token.startswith("system:") and not token.startswith("user:"):
                outputs.append(str(token))
                print(f"Generated token: {token}")
                yield "".join(outputs)
    except Exception as e:
        yield f"Error processing query: {str(e)}"
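

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original script: gradio is imported above
# but never used in this section, and handle_query(message, history) is a
# generator that yields partial strings, which matches the streaming callback
# shape expected by gr.ChatInterface. Assuming this module is meant to run as a
# Gradio Space, it could be wired to a chat UI roughly as below; the title and
# example prompt are hypothetical.
if __name__ == "__main__":
    demo = gr.ChatInterface(
        fn=handle_query,  # a generator fn, so the UI streams the growing answer
        title="Odi - assistente Q&A (sketch)",
        examples=["Che cos'è una blockchain?"],  # hypothetical example query
    )
    demo.launch()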