import torch
import os
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer, AutoTokenizer
from interface import GemmaLLMInterface
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.instructor import InstructorEmbedding
import gradio as gr
from llama_index.core import ChatPromptTemplate
from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader, PromptTemplate, load_index_from_storage
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
import spaces
from huggingface_hub import login
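# Authenticate with the Hugging Face Hub so the gated Gemma checkpoint can be downloaded
# (assumes the Space has a HUGGINGFACE_TOKEN secret configured).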
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    token=True,
)
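# Evaluation mode: the model is used for generation only, so training-specific
# behaviour (e.g. dropout) is disabled.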
model.eval()
# Models that LlamaIndex will use for embeddings and text generation:
Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)
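# These global Settings make every LlamaIndex component (index construction and the
# query engine below) use the local Gemma model and the Instructor embeddings above.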
############################---------------------------------
# Get the parser
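# Documents are split into ~256-token chunks with a 64-token overlap so that
# neighbouring chunks share context at their boundaries.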
parser = SentenceSplitter.from_defaults(
    chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
)
@spaces.GPU(duration=20)
def handle_query(query_str, chathistory):
    # Build the vector index from the local knowledge base
    documents = SimpleDirectoryReader(input_files=["data/blockchainprova.txt"]).load_data()
    nodes = parser.get_nodes_from_documents(documents)
    index = VectorStoreIndex(nodes)
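    # NOTE: the index is rebuilt from data/blockchainprova.txt on every query; it could be
    # persisted and reloaded with load_index_from_storage (already imported) to avoid the repeated work.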
    qa_prompt_str = (
        "Context information is below.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Given the context information and not prior knowledge, "
        "answer the question: {query_str}\n"
    )
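    # {context_str} and {query_str} are placeholders that the query engine fills with
    # the retrieved chunks and the user's question.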
    # Text QA prompt (the Italian system message tells the model it is an Italian
    # assistant named Ossy that only answers pertinent questions or requests)
    chat_text_qa_msgs = [
        (
            "system",
            "Sei un assistente italiano di nome Ossy che risponde solo alle domande o richieste pertinenti. ",
        ),
        ("user", qa_prompt_str),
    ]
    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
    try:
        # Stream the answer and strip Gemma's <end_of_turn> marker before yielding
        streaming_response = index.as_query_engine(
            text_qa_template=text_qa_template, streaming=True
        ).query(query_str)
        answer = ""
        for token in streaming_response.response_gen:
            answer += token
            yield answer.replace("<end_of_turn>", "").strip()
    except Exception as e:
        yield f"Error processing query: {str(e)}"