import torch
import os
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer, AutoTokenizer
from interface import GemmaLLMInterface
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.instructor import InstructorEmbedding
import gradio as gr
from llama_index.core import ChatPromptTemplate
from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader, PromptTemplate, load_index_from_storage
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
import spaces
from huggingface_hub import login
from llama_index.core.memory import ChatMemoryBuffer
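
# Authenticate with the Hugging Face Hub using the token stored in the environment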
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)
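
# Run on the GPU when available, otherwise fall back to CPU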
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
"""model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto", ## change this back to auto!!!
torch_dtype= torch.bfloat16 if torch.cuda.is_available() else torch.float32,
token=True)
model.eval()"""
#from accelerate import disk_offload
#disk_offload(model=model, offload_dir="offload")
# what models will be used by LlamaIndex:
"""Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)"""
Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface(model_id="google/gemma-2-2b-it")
############################---------------------------------
# Sentence splitter used to chunk documents into nodes
parser = SentenceSplitter.from_defaults(
    chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
)

def build_index():
    # Load documents from a file
    documents = SimpleDirectoryReader(input_files=["data/blockchainprova.txt"]).load_data()
    # Parse the documents into nodes
    nodes = parser.get_nodes_from_documents(documents)
    # Build the vector store index from the nodes
    index = VectorStoreIndex(nodes)
    return index

@spaces.GPU(duration=20)
def handle_query(query_str, chathistory):
    index = build_index()

    qa_prompt_str = (
        "Context information is below.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Given the context information and not prior knowledge, "
        "answer the question: {query_str}\n"
    )

    # Text QA prompt. The system message (Italian) reads: "You are an Italian
    # assistant named Ossy who only answers relevant questions or requests."
    chat_text_qa_msgs = [
        (
            "system",
            "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. ",
        ),
        ("user", qa_prompt_str),
    ]
    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)

    try:
        # Alternative (disabled): answer with a non-streaming query engine.
        """query_engine = index.as_query_engine(text_qa_template=text_qa_template, streaming=False, similarity_top_k=1)
        streaming_response = query_engine.query(query_str)
        r = streaming_response.response
        cleaned_result = r.replace("<end_of_turn>", "").strip()
        yield cleaned_result"""

        # Alternative (disabled): stream the query-engine response incrementally.
        """outputs = []
        for text in streaming_response.response_gen:
            outputs.append(str(text))
            yield "".join(outputs)"""

        # Chat engine with a bounded conversation memory over the index context.
        memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
        chat_engine = index.as_chat_engine(
            chat_mode="context",
            memory=memory,
            system_prompt=(
                "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. "
            ),
        )

        # Stream the chat response token by token.
        response = chat_engine.stream_chat(query_str)
        # response = chat_engine.chat(query_str)
        for token in response.response_gen:
            yield token
    except Exception as e:
        yield f"Error processing query: {str(e)}"