import os
from typing import Iterator

import torch
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download, login
from llama_cpp import Llama
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaTokenizerFast, TextIteratorStreamer

from interface import GemmaLLMInterface
from llama_index.core import ChatPromptTemplate, PromptTemplate, Settings, SimpleDirectoryReader, VectorStoreIndex, load_index_from_storage
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.instructor import InstructorEmbedding
# Authenticate with the Hugging Face Hub (the Gemma weights are gated)
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "google/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    token=True,
)
# Load the tokenizer that matches the Gemma 2 checkpoint above
model.tokenizer = AutoTokenizer.from_pretrained(model_id)
model.eval()
# Models used by LlamaIndex: Instructor embeddings plus the custom Gemma interface
Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface()

############---------------------------------
# Get the parser that splits documents into overlapping sentence chunks
parser = SentenceSplitter.from_defaults(
    chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
)
def build_index():
    # Load documents from a file
    documents = SimpleDirectoryReader(input_files=["data/blockchainprova.txt"]).load_data()
    # Parse the documents into nodes
    nodes = parser.get_nodes_from_documents(documents)
    # Build the vector store index from the nodes
    index = VectorStoreIndex(nodes)
    return index
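
# A hedged alternative to rebuilding the index on every query: persist it once and reload it with
# the already-imported load_index_from_storage. The helper name and the "./storage" directory are
# assumptions for illustration, not part of the original Space.
def build_or_load_index():
    from llama_index.core import StorageContext  # not imported at the top of this file
    if os.path.isdir("./storage"):
        storage_context = StorageContext.from_defaults(persist_dir="./storage")
        return load_index_from_storage(storage_context)
    index = build_index()
    index.storage_context.persist(persist_dir="./storage")
    return index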
def handle_query(query_str, chathistory) -> Iterator[str]:
    # The index is rebuilt on every request
    index = build_index()
    try:
        memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
        chat_engine = index.as_chat_engine(
            chat_mode="context",
            memory=memory,
            # Italian system prompt: Odi, a Q&A research assistant created by the Osservatori Digitali
            system_prompt=(
                "Sei un assistente Q&A italiano di nome Odi, che risponde solo alle domande o richieste "
                "pertinenti in modo preciso. Hai una risposta predefinita per quando un utente ti chiede "
                "informazioni su di te o sul tuo creatore, ovvero: 'Sono un assistente ricercatore creato "
                "dagli Osservatori Digitali'."
            ),
        )

        outputs = []
        response = chat_engine.stream_chat(query_str)
        # response = chat_engine.chat(query_str)
        for token in response.response_gen:
            # Skip any role prefixes the model may echo back
            if not token.startswith("system:") and not token.startswith("user:"):
                outputs.append(str(token))
                print(f"Generated token: {token}")
                yield "".join(outputs)
    except Exception as e:
        yield f"Error processing query: {str(e)}"