# chatbot-llamaindex / backend.py
import os

import torch
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download, login
from llama_cpp import Llama
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GemmaTokenizerFast,
    TextIteratorStreamer,
)
from llama_index.core import (
    ChatPromptTemplate,
    PromptTemplate,
    Settings,
    SimpleDirectoryReader,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.instructor import InstructorEmbedding

from interface import GemmaLLMInterface
# Authenticate with the Hugging Face Hub (needed for gated models such as Gemma).
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
"""model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto", ## change this back to auto!!!
torch_dtype= torch.bfloat16 if torch.cuda.is_available() else torch.float32,
token=True)
model.eval()"""
#from accelerate import disk_offload
#disk_offload(model=model, offload_dir="offload")
# what models will be used by LlamaIndex:
"""Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)"""
Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface(model_id="google/gemma-2-2b-it")
# ---------------------------------------------------------------------------
# Sentence-level parser that chunks documents into nodes of ~256 tokens with a
# 64-token overlap, splitting on blank lines first.
parser = SentenceSplitter.from_defaults(
    chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
)
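
# Illustrative helper (an addition, not part of the original backend): shows how
# the parser chunks a plain string. split_text() is the SentenceSplitter API for
# raw text, as opposed to get_nodes_from_documents() for Document objects.
def preview_chunks(text: str):
    chunks = parser.split_text(text)
    for i, chunk in enumerate(chunks):
        # Print a short prefix of each chunk so the overlap is visible.
        print(f"--- chunk {i} ({len(chunk)} chars) ---")
        print(chunk[:120])
    return chunks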
def build_index():
    # Load documents from a file
    documents = SimpleDirectoryReader(input_files=["data/blockchainprova.txt"]).load_data()
    # Parse the documents into nodes
    nodes = parser.get_nodes_from_documents(documents)
    # Build the vector store index from the nodes
    index = VectorStoreIndex(nodes)
    return index
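
# Hedged sketch: handle_query() below rebuilds the index on every call, which is
# wasteful for a static corpus. One option is persisting the index to a local
# directory (the "storage" path here is an assumption) and reloading it with
# load_index_from_storage on later runs.
def build_or_load_index(persist_dir="storage"):
    from llama_index.core import StorageContext

    if os.path.isdir(persist_dir):
        # Reload the previously persisted index instead of re-embedding.
        storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
        return load_index_from_storage(storage_context)
    index = build_index()
    index.storage_context.persist(persist_dir=persist_dir)
    return index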
@spaces.GPU(duration=20)
def handle_query(query_str, chathistory):
    # Note: the index is rebuilt on every query; fine for one small file,
    # but it should be built once (or persisted) for anything larger.
    index = build_index()

    qa_prompt_str = (
        "Context information is below.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Given the context information and not prior knowledge, "
        "answer the question: {query_str}\n"
    )

    # Text QA prompt. The Italian system message reads: "You are an Italian
    # assistant named Ossy who only answers pertinent questions or requests."
    chat_text_qa_msgs = [
        (
            "system",
            "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. ",
        ),
        ("user", qa_prompt_str),
    ]
    # Only consumed by the commented-out query engine below.
    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)

    try:
        # Earlier approach, kept for reference: a one-shot query engine.
        # query_engine = index.as_query_engine(
        #     text_qa_template=text_qa_template, streaming=False, similarity_top_k=1
        # )
        # streaming_response = query_engine.query(query_str)
        # r = streaming_response.response
        # cleaned_result = r.replace("<end_of_turn>", "").strip()
        # yield cleaned_result
        #
        # Streaming variant of the same approach:
        # outputs = []
        # for text in streaming_response.response_gen:
        #     outputs.append(str(text))
        #     yield "".join(outputs)

        # Current approach: a context chat engine with a bounded memory buffer.
        memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
        chat_engine = index.as_chat_engine(
            chat_mode="context",
            memory=memory,
            system_prompt=(
                "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. "
            ),
        )
        response = chat_engine.stream_chat(query_str)
        # response = chat_engine.chat(query_str)  # non-streaming alternative
        for token in response.response_gen:
            yield token
    except Exception as e:
        yield f"Error processing query: {str(e)}"