gufett0 committed
Commit 8bce767 · 1 Parent(s): a5cb440

added introductory prompt

Files changed (1):
  1. backend.py  +68 -57

backend.py CHANGED
@@ -5,8 +5,7 @@ from interface import GemmaLLMInterface
 from llama_index.core.node_parser import SentenceSplitter
 from llama_index.embeddings.instructor import InstructorEmbedding
 import gradio as gr
-from llama_index.core import ChatPromptTemplate
-from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader, PromptTemplate, load_index_from_storage
+from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader, PromptTemplate, load_index_from_storage, StorageContext
 from llama_index.core.node_parser import SentenceSplitter
 from huggingface_hub import hf_hub_download
 from llama_cpp import Llama
@@ -46,11 +45,15 @@ documents_paths = {
     'payment': 'data/payment'
 }
 
-session_state = {"documents_loaded": False,
+session_state = {"index": False,
+                 "documents_loaded": False,
                  "document_db": None,
                  "original_message": None,
                  "clarification": False}
 
+PERSIST_DIR = "./db"
+os.makedirs(PERSIST_DIR, exist_ok=True)
+
 ############################---------------------------------
 
 # Get the parser
@@ -66,47 +69,21 @@ def build_index(path: str):
     # Build the vector store index from the nodes
     index = VectorStoreIndex(nodes)
 
+    storage_context = StorageContext.from_defaults()
+    index.storage_context.persist(persist_dir=PERSIST_DIR)
+
     return index
 
 
-# Global variables
-global_index = None
-global_session_state = {}
-
-def initialize_global_state():
-    global global_index, global_session_state
-    global_index = None
-    global_session_state = {
-        "documents_loaded": False,
-        "document_db": None,
-        "original_message": None,
-        "clarification": False,
-        "conversation": []
-    }
-
-# Call this at the start of your script
-initialize_global_state()
 
 @spaces.GPU(duration=30)
 def handle_query(query_str: str,
-                 chat_history: list[tuple[str, str]],
-                 session: dict[str, Any]) -> Iterator[str]:
-    global global_index, global_session_state
-
-    # Update global session state with any new information from the passed session
-    global_session_state.update(session)
-
-    # Update conversation history
-    for user, assistant in chat_history:
-        global_session_state["conversation"].extend([
-            ChatMessage(role=MessageRole.USER, content=user),
-            ChatMessage(role=MessageRole.ASSISTANT, content=assistant),
-        ])
-
-    # Add current query to conversation
-    global_session_state["conversation"].append(ChatMessage(role=MessageRole.USER, content=query_str))
-
-    if global_index is None:
+                 chat_history: list[tuple[str, str]]) -> Iterator[str]:
+
+    global conversation
+
+    if not session_state["index"]:
         matched_path = None
         words = query_str.lower()
         for key, path in documents_paths.items():
@@ -114,39 +91,73 @@ def handle_query(query_str: str,
                 matched_path = path
                 break
         if matched_path:
-            global_index = build_index(matched_path)
-            global_session_state["documents_loaded"] = True
-        else:
-            global_index = build_index("data/chiarimento.txt")
-            global_session_state["clarification"] = True
+            index = build_index(matched_path)
+            session_state["index"] = True
+
+        else:  ## CHIEDI CHIARIMENTO
+            conversation: List[ChatMessage] = []
+            for user, assistant in chat_history:
+                conversation.extend(
+                    [
+                        ChatMessage(role=MessageRole.USER, content=user),
+                        ChatMessage(role=MessageRole.ASSISTANT, content=assistant),
+                    ]
+                )
+            index = build_index("data/chiarimento.txt")
+
+    else:
+        # The index is already built, no need to rebuild it.
+        conversation: List[ChatMessage] = []
+        for user, assistant in chat_history:
+            conversation.extend(
+                [
+                    ChatMessage(role=MessageRole.USER, content=user),
+                    ChatMessage(role=MessageRole.ASSISTANT, content=assistant),
+                ]
+            )
+
+        #conversation.append(ChatMessage(role=MessageRole.USER, content=query_str))
 
     try:
+        storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
+        index = load_index_from_storage(storage_context)
         memory = ChatMemoryBuffer.from_defaults(token_limit=None)
 
-        chat_engine = global_index.as_chat_engine(
-            chat_mode="condense_plus_context",
-            memory=memory,
-            similarity_top_k=4,
-            response_mode="tree_summarize",
-            context_prompt = (
-                "Sei un assistente Q&A italiano di nome Odi, che risponde solo alle domande o richieste pertinenti in modo preciso."
-                " Quando un utente ti chiede informazioni su di te o sul tuo creatore puoi dire che sei un assistente ricercatore creato dagli Osservatori Digitali e fornire gli argomenti di cui sei esperto."
-                " Ecco i documenti rilevanti per il contesto:\n"
-                "{context_str}"
-                "\nIstruzione: Usa la cronologia delle chat precedenti, o il contesto sopra, per interagire e aiutare l'utente a rispondere alla sua domanda."
-            ),
-            verbose=False,
+        chat_engine = index.as_chat_engine(
+            chat_mode="condense_plus_context",
+            memory=memory,
+            similarity_top_k=4,
+            response_mode="tree_summarize",  # good for summarization purposes
+            context_prompt = (
+                "Sei un assistente Q&A italiano di nome Odi, che risponde solo alle domande o richieste pertinenti in modo preciso."
+                " Quando un utente ti chiede informazioni su di te o sul tuo creatore puoi dire che sei un assistente ricercatore creato dagli Osservatori Digitali e fornire gli argomenti di cui sei esperto."
+                " Ecco i documenti rilevanti per il contesto:\n"
+                "{context_str}"
+                "\nIstruzione: Usa la cronologia delle chat precedenti, o il contesto sopra, per interagire e aiutare l'utente a rispondere alla sua domanda."
+            ),
+            verbose=False,
         )
 
-        response = chat_engine.stream_chat(query_str, global_session_state["conversation"])
-
         outputs = []
+        response = chat_engine.stream_chat(query_str, conversation)
+        #response = chat_engine.chat(query_str)
        for token in response.response_gen:
+            #if not token.startswith("system:") and not token.startswith("user:"):
             outputs.append(token)
+            #print(f"Generated token: {token}")
             yield "".join(outputs)
-
-        # Update the session with any changes
-        session.update(global_session_state)
+
 
     except Exception as e:
         yield f"Error processing query: {str(e)}"