# chatbot-llamaindex / backend.py
import torch
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from interface import GemmaLLMInterface
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.instructor import InstructorEmbedding
import gradio as gr
from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader
import spaces
from huggingface_hub import login
from llama_index.core.memory import ChatMemoryBuffer
from typing import Iterator
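
# Authenticate with the Hugging Face Hub; the gated Gemma weights require a valid
# HUGGINGFACE_TOKEN to be set in the environment.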
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)
# Reference device; actual model placement is handled by device_map="auto" below.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "google/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    token=True,
)
# Attach the tokenizer from the same checkpoint as the model so downstream code
# (e.g. GemmaLLMInterface) can reach both through a single object.
model.tokenizer = AutoTokenizer.from_pretrained(model_id)
model.eval()  # inference only
# Models that LlamaIndex will use (embeddings + LLM):
Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface()
############################---------------------------------
# Sentence splitter: chunk documents into ~256-token nodes with a 64-token overlap
parser = SentenceSplitter.from_defaults(
    chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
)
def build_index():
# Load documents from a file
documents = SimpleDirectoryReader(input_files=["data/blockchainprova.txt"]).load_data()
# Parse the documents into nodes
nodes = parser.get_nodes_from_documents(documents)
# Build the vector store index from the nodes
index = VectorStoreIndex(nodes)
return index
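
# Note: build_index() is re-run for every query in handle_query() below. As a minimal,
# hedged sketch (not part of the original flow), the index could instead be cached at
# module level so the documents are only embedded once per process:
_cached_index = None

def get_cached_index():
    """Hypothetical helper: build the index on first use and reuse it afterwards."""
    global _cached_index
    if _cached_index is None:
        _cached_index = build_index()
    return _cached_index
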
@spaces.GPU(duration=20)
def handle_query(query_str: str, chathistory) -> Iterator[str]:
    # Rebuild the index for every request; simple, but the documents are re-embedded on each call.
    index = build_index()
    try:
        # Short-term conversation memory for the chat engine; it is created fresh here,
        # so earlier turns from `chathistory` are not replayed into it.
        memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
        chat_engine = index.as_chat_engine(
            chat_mode="context",
            memory=memory,
            # System prompt (Italian): "Odi", a Q&A assistant that only answers relevant
            # questions and gives a fixed reply when asked about itself or its creator.
            system_prompt=(
                "Sei un assistente Q&A italiano di nome Odi, che risponde solo alle domande o richieste pertinenti in modo preciso. Hai una risposta predefinita per quando un utente ti chiede informazioni su di te o sul tuo creatore, ovvero: 'Sono un assistente ricercatore creato dagli Osservatori Digitali'."
            ),
        )

        outputs = []
        response = chat_engine.stream_chat(query_str)
        # non-streaming alternative: response = chat_engine.chat(query_str)
        for token in response.response_gen:
            # Skip any role markers the model echoes back; accumulate everything else
            # and yield the growing answer so the frontend can stream it.
            if not token.startswith("system:") and not token.startswith("user:"):
                outputs.append(str(token))
                print(f"Generated token: {token}")
                yield "".join(outputs)
    except Exception as e:
        yield f"Error processing query: {str(e)}"