import os
from typing import Any, Iterator, List

import gradio as gr
import spaces
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaTokenizerFast, TextIteratorStreamer

from interface import GemmaLLMInterface
from llama_index.core import (
    PromptTemplate,
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.chat_engine import CondensePlusContextChatEngine
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.instructor import InstructorEmbedding
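
# Authenticate with the Hugging Face Hub (HUGGINGFACE_TOKEN must be set in the Space secrets/environment).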
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
login(huggingface_token)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "google/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    token=True,
)
# Use the tokenizer that matches the model checkpoint (the original loaded
# the Gemma 1 tokenizer "google/gemma-2b-it" for a Gemma 2 model).
model.tokenizer = AutoTokenizer.from_pretrained(model_id)
model.eval()
# Models LlamaIndex will use: InstructorEmbedding for embeddings, the custom
# GemmaLLMInterface (from interface.py) for generation.
Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
Settings.llm = GemmaLLMInterface()
documents_paths = {
    'blockchain': 'data/blockchainprova.txt',
    'metaverse': 'data/metaverso',
    'payment': 'data/payment',
}
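
# In-memory session state shared across requests (tracks whether an index has been built).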
session_state = {"index": False,
"documents_loaded": False,
"document_db": None,
"original_message": None,
"clarification": False}
PERSIST_DIR = "./db"
os.makedirs(PERSIST_DIR, exist_ok=True)
############################---------------------------------
# Configure the sentence splitter used to chunk documents into nodes
parser = SentenceSplitter.from_defaults(
    chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
)
def build_index(path: str):
    # Load documents from a file
    documents = SimpleDirectoryReader(input_files=[path]).load_data()
    # Parse the documents into nodes
    nodes = parser.get_nodes_from_documents(documents)
    # Build the vector store index from the nodes and persist it for later reloads
    index = VectorStoreIndex(nodes)
    index.storage_context.persist(persist_dir=PERSIST_DIR)
    return index
@spaces.GPU(duration=20)
def handle_query(query_str: str,
                 chat_history: list[tuple[str, str]]) -> Iterator[str]:
    # Rebuild the conversation history as LlamaIndex ChatMessage objects
    conversation: List[ChatMessage] = []
    for user, assistant in chat_history:
        conversation.extend([
            ChatMessage(role=MessageRole.USER, content=user),
            ChatMessage(role=MessageRole.ASSISTANT, content=assistant),
        ])
if not session_state["index"]:
matched_path = None
words = query_str.lower()
for key, path in documents_paths.items():
if key in words:
matched_path = path
break
if matched_path:
index = build_index(matched_path)
session_state["index"] = True
else: ## CHIEDI CHIARIMENTO
index = build_index("data/chiarimento.txt")
else:
# The index is already built, no need to rebuild it.
storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
index = load_index_from_storage(storage_context)
    try:
        memory = ChatMemoryBuffer.from_defaults(token_limit=None)
        chat_engine = index.as_chat_engine(
            chat_mode="condense_plus_context",
            memory=memory,
            similarity_top_k=4,
            response_mode="tree_summarize",  # good for summarization purposes
            # The context prompt is kept in Italian because the assistant answers in Italian
            context_prompt=(
                "Sei un assistente Q&A italiano di nome Odi, che risponde solo alle domande o richieste pertinenti in modo preciso."
                " Quando un utente ti chiede informazioni su di te o sul tuo creatore puoi dire che sei un assistente ricercatore creato dagli Osservatori Digitali e fornire gli argomenti di cui sei esperto."
                " Ecco i documenti rilevanti per il contesto:\n"
                "{context_str}"
                "\nIstruzione: Usa la cronologia delle chat precedenti, o il contesto sopra, per interagire e aiutare l'utente a rispondere alla sua domanda."
            ),
            verbose=False,
        )
        # Stream tokens back to the UI as they are generated
        outputs = []
        response = chat_engine.stream_chat(query_str, conversation)
        for token in response.response_gen:
            outputs.append(token)
            yield "".join(outputs)
    except Exception as e:
        yield f"Error processing query: {str(e)}"