added introductory prompt
backend.py  CHANGED  (+68 -57)
@@ -5,8 +5,7 @@ from interface import GemmaLLMInterface
 from llama_index.core.node_parser import SentenceSplitter
 from llama_index.embeddings.instructor import InstructorEmbedding
 import gradio as gr
-from llama_index.core import
-from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader, PromptTemplate, load_index_from_storage
+from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader, PromptTemplate, load_index_from_storage, StorageContext
 from llama_index.core.node_parser import SentenceSplitter
 from huggingface_hub import hf_hub_download
 from llama_cpp import Llama
@@ -46,11 +45,15 @@ documents_paths = {
     'payment': 'data/payment'
 }
 
-session_state = {"
+session_state = {"index": False,
+                 "documents_loaded": False,
                  "document_db": None,
                  "original_message": None,
                  "clarification": False}
 
+PERSIST_DIR = "./db"
+os.makedirs(PERSIST_DIR, exist_ok=True)
+
 ############################---------------------------------
 
 # Get the parser
@@ -66,47 +69,21 @@ def build_index(path: str):
     # Build the vector store index from the nodes
     index = VectorStoreIndex(nodes)
 
+    storage_context = StorageContext.from_defaults()
+    index.storage_context.persist(persist_dir=PERSIST_DIR)
+
     return index
 
 
-# Global variables
-global_index = None
-global_session_state = {}
-
-def initialize_global_state():
-    global global_index, global_session_state
-    global_index = None
-    global_session_state = {
-        "documents_loaded": False,
-        "document_db": None,
-        "original_message": None,
-        "clarification": False,
-        "conversation": []
-    }
-
-# Call this at the start of your script
-initialize_global_state()
 
 @spaces.GPU(duration=30)
 def handle_query(query_str: str,
-                 chat_history: list[tuple[str, str]]
-                 session: dict[str, Any]) -> Iterator[str]:
-    global global_index, global_session_state
-
-    # Update global session state with any new information from the passed session
-    global_session_state.update(session)
+                 chat_history: list[tuple[str, str]]) -> Iterator[str]:
 
-
-    for user, assistant in chat_history:
-        global_session_state["conversation"].extend([
-            ChatMessage(role=MessageRole.USER, content=user),
-            ChatMessage(role=MessageRole.ASSISTANT, content=assistant),
-        ])
+    global conversation
 
-    # Add current query to conversation
-    global_session_state["conversation"].append(ChatMessage(role=MessageRole.USER, content=query_str))
 
-    if
+    if not session_state["index"]:
         matched_path = None
         words = query_str.lower()
         for key, path in documents_paths.items():
@@ -114,39 +91,73 @@ def handle_query(query_str: str,
                 matched_path = path
                 break
         if matched_path:
-
-
-
-
-
+            index = build_index(matched_path)
+            session_state["index"] = True
+
+        else: ## CHIEDI CHIARIMENTO
+            conversation: List[ChatMessage] = []
+            for user, assistant in chat_history:
+                conversation.extend(
+                    [
+                        ChatMessage(role=MessageRole.USER, content=user),
+
+                        ChatMessage(role=MessageRole.ASSISTANT, content=assistant),
+                    ]
+                )
+            index = build_index("data/chiarimento.txt")
+
+
+
+    else:
+
+
+        # The index is already built, no need to rebuild it.
+        conversation: List[ChatMessage] = []
+        for user, assistant in chat_history:
+            conversation.extend(
+                [
+                    ChatMessage(role=MessageRole.USER, content=user),
+
+                    ChatMessage(role=MessageRole.ASSISTANT, content=assistant),
+                ]
+            )
+
+        #conversation.append( ChatMessage(role=MessageRole.USER, content=query_str))
+        #pass
 
     try:
+
+        storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
+        index = load_index_from_storage(storage_context)
         memory = ChatMemoryBuffer.from_defaults(token_limit=None)
 
-        chat_engine =
-
-
-
-
-
-
-
-
-
-
-
+        chat_engine = index.as_chat_engine(
+            chat_mode="condense_plus_context",
+            memory=memory,
+            similarity_top_k=4,
+            response_mode="tree_summarize", #Good for summarization purposes
+
+            context_prompt = (
+                "Sei un assistente Q&A italiano di nome Odi, che risponde solo alle domande o richieste pertinenti in modo preciso."
+                " Quando un utente ti chiede informazioni su di te o sul tuo creatore puoi dire che sei un assistente ricercatore creato dagli Osservatori Digitali e fornire gli argomenti di cui sei esperto."
+                " Ecco i documenti rilevanti per il contesto:\n"
+                "{context_str}"
+                "\nIstruzione: Usa la cronologia delle chat precedenti, o il contesto sopra, per interagire e aiutare l'utente a rispondere alla sua domanda."
+            ),
+            verbose=False,
         )
 
-        response = chat_engine.stream_chat(query_str, global_session_state["conversation"])
 
         outputs = []
+        response = chat_engine.stream_chat(query_str, conversation)
+        #response = chat_engine.chat(query_str)
        for token in response.response_gen:
+            #if not token.startswith("system:") and not token.startswith("user:"):
+
            outputs.append(token)
+            #print(f"Generated token: {token}")
            yield "".join(outputs)
-
-        # Update the session with any changes
-        session.update(global_session_state)
+
 
     except Exception as e:
        yield f"Error processing query: {str(e)}"