import os

import torch
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer, AutoTokenizer
from interface import GemmaLLMInterface
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.instructor import InstructorEmbedding
import gradio as gr
from llama_index.core import (
    ChatPromptTemplate,
    Settings,
    VectorStoreIndex,
    SimpleDirectoryReader,
    PromptTemplate,
    load_index_from_storage,
)
from huggingface_hub import hf_hub_download, login
from llama_cpp import Llama
import spaces
from llama_index.core.memory import ChatMemoryBuffer

# Authenticate against the Hugging Face Hub (required for the gated Gemma weights).
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the Gemma 2 2B instruction-tuned model and its tokenizer.
model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    token=True,
)
model.eval()

# from accelerate import disk_offload
# disk_offload(model=model, offload_dir="offload")

# Models used by LlamaIndex: Instructor embeddings plus the Gemma wrapper.
# Previously the wrapper was built from the objects loaded above:
#   Settings.llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)
Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface(model_id="google/gemma-2-2b-it")

# ---------------------------------------------------------------------------

# Node parser: split documents into overlapping sentence-based chunks.
parser = SentenceSplitter.from_defaults(
    chunk_size=256,
    chunk_overlap=64,
    paragraph_separator="\n\n",
)


def build_index():
    # Load documents from a file.
    documents = SimpleDirectoryReader(input_files=["data/blockchainprova.txt"]).load_data()
    # Parse the documents into nodes.
    nodes = parser.get_nodes_from_documents(documents)
    # Build the vector store index from the nodes.
    index = VectorStoreIndex(nodes)
    return index


@spaces.GPU(duration=20)
def handle_query(query_str, chathistory):
    index = build_index()

    qa_prompt_str = (
        "Context information is below.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Given the context information and not prior knowledge, "
        "answer the question: {query_str}\n"
    )
    # Text QA prompt. The system message (in Italian) says: "You are an Italian
    # assistant named Ossy who only answers relevant questions or requests."
    chat_text_qa_msgs = [
        (
            "system",
            "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. ",
        ),
        ("user", qa_prompt_str),
    ]
    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)

    try:
        # Alternative, non-streaming path that uses the QA template above:
        #   query_engine = index.as_query_engine(
        #       text_qa_template=text_qa_template, streaming=False, similarity_top_k=1
        #   )
        #   yield query_engine.query(query_str).response

        # Context chat engine with a small rolling memory; stream tokens back.
        memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
        chat_engine = index.as_chat_engine(
            chat_mode="context",
            memory=memory,
            system_prompt=(
                "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. "
            ),
        )
        response = chat_engine.stream_chat(query_str)
        # response = chat_engine.chat(query_str)
        for token in response.response_gen:
            yield token
    except Exception as e:
        yield f"Error processing query: {str(e)}"
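
# ---------------------------------------------------------------------------
# Illustrative UI wiring (a sketch only; the actual Gradio interface is defined
# elsewhere in the app and may differ). gr.ChatInterface passes (message,
# history) to its fn and, for generator functions, treats each yielded value as
# the full response so far, so the per-token yields from handle_query would
# need to be accumulated before being yielded to the UI:
# ---------------------------------------------------------------------------
# def chat_fn(message, history):
#     partial = ""
#     for token in handle_query(message, history):
#         partial += token
#         yield partial
#
# demo = gr.ChatInterface(fn=chat_fn, title="Ossy")
# demo.launch()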