import os

import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer, AutoTokenizer
from huggingface_hub import hf_hub_download, login
from llama_cpp import Llama
from llama_index.core import ChatPromptTemplate
from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader, PromptTemplate, load_index_from_storage
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.instructor import InstructorEmbedding

from interface import GemmaLLMInterface

# Authenticate against the Hugging Face Hub (token is provided via an environment variable / Space secret).
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)
"""hf_hub_download( | |
repo_id="google/gemma-2-2b-it-GGUF", | |
filename="2b_it_v2.gguf", | |
local_dir="./models", | |
token=huggingface_token | |
) | |
llm = Llama( | |
model_path=f"models/2b_it_v2.gguf", | |
flash_attn=True, | |
_gpu_layers=81, | |
n_batch=1024, | |
n_ctx=8192, | |
)""" | |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    token=True,
)
model.eval()
# Models used by LlamaIndex: an Instructor embedding model for retrieval and
# the Gemma wrapper (defined in interface.py) as the LLM.
Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)
# Settings.llm = llm
############################---------------------------------
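# NOTE (assumption, not in the original section): on ZeroGPU hardware this handler
# would typically be decorated with @spaces.GPU, which is the likely reason for the
# `spaces` import above.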
def handle_query(query_str, chathistory):
    # Read the knowledge-base document from disk.
    documents = SimpleDirectoryReader(input_files=["data/blockchainprova.txt"]).load_data()

    # Split the document into overlapping chunks.
    parser = SentenceSplitter.from_defaults(
        chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
    )
    nodes = parser.get_nodes_from_documents(documents)

    # Build an in-memory vector index over the chunks.
    index = VectorStoreIndex(nodes)
    qa_prompt_str = (
        "Context information is below.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Given the context information and not prior knowledge, "
        "answer the question: {query_str}\n"
    )
    # Text QA prompt. The system message (in Italian) says: "You are an Italian
    # assistant named Ossy who only answers relevant questions or requests."
    chat_text_qa_msgs = [
        (
            "system",
            "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. ",
        ),
        ("user", qa_prompt_str),
    ]
    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
    try:
        result = index.as_query_engine(text_qa_template=text_qa_template).query(query_str)
        response_text = result.response

        # Strip any leftover Gemma control tokens such as <end_of_turn>.
        cleaned_result = response_text.replace("<end_of_turn>", "").strip()
        yield cleaned_result
    except Exception as e:
        yield f"Error processing query: {str(e)}"