# chatbot-llamaindex / backend.py
# Source: Hugging Face Space by gufett0, commit 3f367eb ("added cuda support"), 3.17 kB.
import torch
import os
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer, AutoTokenizer
from interface import GemmaLLMInterface
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.instructor import InstructorEmbedding
import gradio as gr
from llama_index.core import ChatPromptTemplate
from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader, PromptTemplate, load_index_from_storage
from llama_index.core.node_parser import SentenceSplitter
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
import spaces
from huggingface_hub import login
# Authenticate with the Hugging Face Hub (the Gemma checkpoints are gated).
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
# Only log in when a token is configured: huggingface_hub.login(None) falls back
# to an interactive prompt, which blocks or fails on a headless server.
if huggingface_token:
    login(huggingface_token)

# Alternative llama.cpp (GGUF) backend, currently disabled. To use it, uncomment
# this snippet and the `Settings.llm = llm` line below.
# NOTE: the llama_cpp parameter for GPU offload is `n_gpu_layers`; an earlier
# draft spelled it `_gpu_layers`, which is not a recognised `Llama` parameter.
# hf_hub_download(
#     repo_id="google/gemma-2-2b-it-GGUF",
#     filename="2b_it_v2.gguf",
#     local_dir="./models",
#     token=huggingface_token,
# )
# llm = Llama(
#     model_path="models/2b_it_v2.gguf",
#     flash_attn=True,
#     n_gpu_layers=81,
#     n_batch=1024,
#     n_ctx=8192,
# )

use_cuda = torch.cuda.is_available()
# Not used for placement in this file (device_map="auto" below handles that);
# kept as a module-level name in case other modules import it.
device = torch.device("cuda:0" if use_cuda else "cpu")

model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    # bfloat16 on GPU to halve memory; full precision on CPU.
    torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
    # Use the locally stored Hub token (set by `login` above, if any).
    token=True,
)
model.eval()  # inference only: disable dropout etc.

# Global models used by LlamaIndex for embedding and generation.
Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)
# Settings.llm = llm  # use with the llama.cpp backend above
############################---------------------------------
# User-turn template for answering from retrieved context. `{context_str}` and
# `{query_str}` are filled in by LlamaIndex when the template is passed as
# `text_qa_template`.
_QA_PROMPT_STR = (
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Given the context information and not prior knowledge, "
    "answer the question: {query_str}\n"
)
# System message (Italian): "You are an Italian assistant named Ossy who only
# answers pertinent questions or requests."
_SYSTEM_PROMPT = (
    "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. "
)


def _build_index():
    """Read the knowledge-base file and build an in-memory vector index over it.

    The text is split into ~256-token chunks with 64-token overlap, using blank
    lines as paragraph separators, and embedded with the globally configured
    ``Settings.embed_model``.
    """
    documents = SimpleDirectoryReader(
        input_files=["data/blockchainprova.txt"]
    ).load_data()
    parser = SentenceSplitter.from_defaults(
        chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
    )
    nodes = parser.get_nodes_from_documents(documents)
    return VectorStoreIndex(nodes)


def _build_qa_template():
    """Return the chat prompt (system persona + context/question user turn)."""
    return ChatPromptTemplate.from_messages(
        [
            ("system", _SYSTEM_PROMPT),
            ("user", _QA_PROMPT_STR),
        ]
    )


@spaces.GPU(duration=120)
def handle_query(query_str, chathistory):
    """Answer ``query_str`` using retrieval over the local knowledge-base file.

    This is a generator that yields exactly one string. On success it is the
    model's answer with the ``<end_of_turn>`` marker stripped. If loading,
    indexing or querying fails, it is an ``"Error processing query: ..."`` message.

    Args:
        query_str: The user's question.
        chathistory: Prior conversation turns. Not used here; presumably kept
            because the chat UI calls handlers as ``(message, history)`` —
            confirm against the caller.
    """
    try:
        # NOTE(review): the data file is re-read, re-chunked and re-embedded on
        # every query; consider caching the index if the file is static.
        index = _build_index()
        query_engine = index.as_query_engine(text_qa_template=_build_qa_template())
        result = query_engine.query(query_str)
        # LlamaIndex types `Response.response` as Optional[str]; treat None as
        # empty instead of failing on `.replace` below.
        response_text = result.response or ""
    except Exception as exc:  # UI boundary: report the failure in the chat
        import logging

        logging.getLogger(__name__).exception("handle_query failed")
        yield f"Error processing query: {exc}"
        return
    # Strip Gemma's end-of-turn marker if it leaks into the generated text.
    yield response_text.replace("<end_of_turn>", "").strip()