Commit 8d2f9d4
1 Parent(s): 44e73ac

Add prompt edit and api key config
Files changed:
- .gitignore +3 -0
- app/api/api_file.py +261 -0
- app/api/api_prompt.py +32 -0
- app/app.py +153 -0
- app/crud/process_file.py +152 -0
- app/handlers/media_handler.py +3 -3
- app/handlers/message_handler.py +7 -5
- app/handlers/webhook_handler.py +4 -2
- app/main.py +12 -4
- app/search/bm25_search.py +159 -0
- app/search/faiss_search.py +76 -0
- app/search/hybrid_search.py +247 -0
- app/search/rag_pipeline.py +147 -0
- app/services/message.py +75 -7
- app/services/search_engine.py +17 -1
- app/utils/load_env.py +2 -2
- app/utils/system_prompt.py +6 -0
- app/utils/token_counter.py +29 -0
- app/utils/tool_call_extractor.py +149 -0
- document_logs_2024-12-20.txt +1 -0
- testcode.py +138 -2
.gitignore
CHANGED
@@ -2,5 +2,8 @@
 __pycache__
 .env
 user_media/
+toolkits/
+test*.py
+
app/api/api_file.py
ADDED
@@ -0,0 +1,261 @@
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, Request, Query, status
from fastapi.responses import StreamingResponse
import os
import logging
import uuid
from datetime import datetime

from pydantic import BaseModel, Field
from typing import Optional, List, Any
from urllib.parse import urlparse
import shutil
# from app.wrapper.llm_wrapper import *
from app.crud.process_file import load_file_with_markitdown, process_uploaded_file

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def is_url(path: str) -> bool:
    """
    Determines if the given path is a URL.

    Args:
        path (str): The path or URL to check.

    Returns:
        bool: True if it's a URL, False otherwise.
    """
    try:
        result = urlparse(path)
        return all([result.scheme, result.netloc])
    except Exception:
        return False


file_router = APIRouter()

# Configure logging to file with date-based filenames
log_filename = f"document_logs_{datetime.now().strftime('%Y-%m-%d')}.txt"
file_handler = logging.FileHandler(log_filename)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)

# Create a logger for document processing
doc_logger = logging.getLogger('document_logger')
doc_logger.setLevel(logging.INFO)
doc_logger.addHandler(file_handler)

# Also configure the general logger if not already configured
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

from app.search.rag_pipeline import RAGSystem
from sentence_transformers import SentenceTransformer


@file_router.post("/load_file_with_markdown/")
async def load_file_with_markdown(request: Request, filepaths: List[str]):
    try:
        # Ensure RAG system is initialized
        try:
            rag_system = request.app.state.rag_system
            if rag_system is None:
                raise AttributeError("RAG system is not initialized in app state")
        except AttributeError:
            logger.error("RAG system is not initialized in app state")
            raise HTTPException(status_code=500, detail="RAG system not initialized in app state")

        processed_files = []
        pages = []

        # Process each file path or URL
        for path in filepaths:
            if is_url(path):
                logger.info(f"Processing URL: {path}")
                try:
                    # Generate a unique UUID for the document
                    doc_id = str(uuid.uuid4())

                    # Process the URL
                    document = await process_uploaded_file(id=doc_id, file_path=path, rag_system=rag_system)

                    # Append the document details to pages
                    pages.append({
                        "metadata": {"title": document.title},
                        "page_content": document.text_content,
                    })

                    logger.info(f"Successfully processed URL: {path} with ID: {doc_id}")

                    # Log the ID and a 100-character snippet of the document
                    snippet = document.text_content[:100].replace('\n', ' ').replace('\r', ' ')
                    # Ensure 'doc_logger' is defined; if not, use 'logger' or define 'doc_logger'
                    doc_logger.info(f"ID: {doc_id}_{document.title}, Snippet: {snippet}")

                except Exception as e:
                    logger.error(f"Error processing URL {path}: {str(e)}")
                    processed_files.append({"path": path, "status": "error", "message": str(e)})

            else:
                logger.info(f"Processing local file: {path}")
                if os.path.exists(path):
                    try:
                        # Generate a unique UUID for the document
                        doc_id = str(uuid.uuid4())

                        # Process the local file
                        document = await process_uploaded_file(id=doc_id, file_path=path, rag_system=rag_system)

                        # Append the document details to pages
                        pages.append({
                            "metadata": {"title": document.title},
                            "page_content": document.text_content,
                        })

                        logger.info(f"Successfully processed file: {path} with ID: {doc_id}")

                        # Log the ID and a 100-character snippet of the document
                        snippet = document.text_content[:100].replace('\n', ' ').replace('\r', ' ')
                        # Ensure 'doc_logger' is defined; if not, use 'logger' or define 'doc_logger'
                        logger.info(f"ID: {doc_id}, Snippet: {snippet}")

                    except Exception as e:
                        logger.error(f"Error processing file {path}: {str(e)}")
                        processed_files.append({"path": path, "status": "error", "message": str(e)})
                else:
                    logger.error(f"File path does not exist: {path}")
                    processed_files.append({"path": path, "status": "not found"})

        # Get total tokens from RAG system
        total_tokens = rag_system.get_total_tokens() if hasattr(rag_system, "get_total_tokens") else 0

        return {
            "message": "File processing completed",
            "total_tokens": total_tokens,
            "document_count": len(filepaths),
            "pages": pages,
            "errors": processed_files,  # Include details about files that couldn't be processed
        }

    except Exception as e:
        logger.exception("Unexpected error during file processing")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")

async def load_file_with_markdown_function(filepaths: List[str],
                                           rag_system: Any):
    try:
        # Ensure RAG system is initialized
        try:
            rag_system = rag_system
        except AttributeError:
            logger.error("RAG system is not initialized in app state")
            raise HTTPException(status_code=500, detail="RAG system not initialized in app state")

        processed_files = []
        pages = []

        # Process each file path or URL
        for path in filepaths:
            if is_url(path):
                logger.info(f"Processing URL: {path}")
                try:
                    # Generate a unique UUID for the document
                    doc_id = str(uuid.uuid4())

                    # Process the URL
                    document = await process_uploaded_file(id=doc_id, file_path=path, rag_system=rag_system)

                    # Append the document details to pages
                    pages.append({
                        "metadata": {"title": document.title},
                        "page_content": document.text_content,
                    })

                    logger.info(f"Successfully processed URL: {path} with ID: {doc_id}")

                    # Log the ID and a 100-character snippet of the document
                    snippet = document.text_content[:100].replace('\n', ' ').replace('\r', ' ')
                    # Ensure 'doc_logger' is defined; if not, use 'logger' or define 'doc_logger'
                    doc_logger(f"ID: {doc_id}, Snippet: {snippet}")
                    logger.info(f"ID: {doc_id}, Snippet: {snippet}")

                except Exception as e:
                    logger.error(f"Error processing URL {path}: {str(e)}")
                    processed_files.append({"path": path, "status": "error", "message": str(e)})

            else:
                logger.info(f"Processing local file: {path}")
                if os.path.exists(path):
                    try:
                        # Generate a unique UUID for the document
                        doc_id = str(uuid.uuid4())

                        # Process the local file
                        document = await process_uploaded_file(id=doc_id, file_path=path, rag_system=rag_system)

                        # Append the document details to pages
                        pages.append({
                            "metadata": {"title": document.title},
                            "page_content": document.text_content,
                        })

                        logger.info(f"Successfully processed file: {path} with ID: {doc_id}")

                        # Log the ID and a 100-character snippet of the document
                        snippet = document.text_content[:100].replace('\n', ' ').replace('\r', ' ')
                        # Ensure 'doc_logger' is defined; if not, use 'logger' or define 'doc_logger'
                        logger.info(f"ID: {doc_id}, Snippet: {snippet}")

                    except Exception as e:
                        logger.error(f"Error processing file {path}: {str(e)}")
                        processed_files.append({"path": path, "status": "error", "message": str(e)})
                else:
                    logger.error(f"File path does not exist: {path}")
                    processed_files.append({"path": path, "status": "not found"})

        # Get total tokens from RAG system
        total_tokens = rag_system.get_total_tokens() if hasattr(rag_system, "get_total_tokens") else 0

        return {
            "message": "File processing completed",
            "total_tokens": total_tokens,
            "document_count": len(filepaths),
            "pages": pages,
            "errors": processed_files,  # Include details about files that couldn't be processed
        }

    except Exception as e:
        logger.exception("Unexpected error during file processing")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")

@file_router.get("/document_exists/{doc_id}", status_code=status.HTTP_200_OK)
async def document_exists(request: Request, doc_id: str):
    try:
        rag_system = request.app.state.rag_system
    except AttributeError:
        logger.error("RAG system is not initialized in app state")
        raise HTTPException(status_code=500, detail="RAG system not initialized in app state")

    exists = doc_id in rag_system.doc_ids
    return {"document_id": doc_id, "exists": exists}

@file_router.delete("/delete_document/{doc_id}", status_code=status.HTTP_200_OK)
async def delete_document(request: Request, doc_id: str):
    try:
        rag_system = request.app.state.rag_system
    except AttributeError:
        logger.error("RAG system is not initialized in app state")
        raise HTTPException(status_code=500, detail="RAG system not initialized in app state")

    try:
        rag_system.delete_document(doc_id)
        logger.info(f"Deleted document with ID: {doc_id}")
        return {"message": f"Document with ID {doc_id} has been deleted."}
    except Exception as e:
        logger.error(f"Error deleting document with ID {doc_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to delete document: {str(e)}")
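Usage sketch (not part of the commit): a minimal client call against the new endpoint, assuming the app from app/app.py runs locally on port 8000 and that file_router is mounted under the /file_load prefix as in this commit; the file path and URL below are placeholders.

import requests

resp = requests.post(
    "http://localhost:8000/file_load/load_file_with_markdown/",
    json=["./docs/handbook.pdf", "https://example.com/page.html"],  # placeholder inputs
)
data = resp.json()
print(data["total_tokens"], data["document_count"])
for page in data["pages"]:
    print(page["metadata"]["title"])

Because filepaths is declared as a bare List[str] parameter, FastAPI reads it from the JSON request body, so posting a plain JSON array is enough.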
app/api/api_prompt.py
ADDED
@@ -0,0 +1,32 @@
from fastapi import FastAPI, HTTPException, APIRouter
from pydantic import BaseModel
from app.utils.system_prompt import system_prompt, agentic_prompt

prompt_router = APIRouter()

# Define a model for the prompts
class Prompt(BaseModel):
    system_prompt: str = None
    agentic_prompt: str = None

# API endpoint to get the current prompts
@prompt_router.get("/prompts")
def get_prompts():
    return {
        "system_prompt": system_prompt,
        "agentic_prompt": agentic_prompt,
    }

# API endpoint to update the prompts
@prompt_router.put("/prompts")
def update_prompts(prompts: Prompt):
    global system_prompt, agentic_prompt
    if prompts.system_prompt is not None:
        system_prompt = prompts.system_prompt
    if prompts.agentic_prompt is not None:
        agentic_prompt = prompts.agentic_prompt
    return {
        "message": "Prompts updated successfully",
        "system_prompt": system_prompt,
        "agentic_prompt": agentic_prompt,
    }
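Usage sketch (not part of the commit): reading and updating the prompts over HTTP, assuming the service runs on localhost:8000 and prompt_router is mounted with prefix="/prompts" as in app/main.py below, which makes the effective path /prompts/prompts.

import requests

base = "http://localhost:8000/prompts/prompts"

current = requests.get(base).json()   # {"system_prompt": ..., "agentic_prompt": ...}
updated = requests.put(base, json={"system_prompt": "You are a concise assistant."}).json()
print(updated["message"])

Note that the global statement only rebinds the names inside api_prompt.py; modules that imported the strings from app.utils.system_prompt before the PUT keep the values they already hold.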
app/app.py
ADDED
@@ -0,0 +1,153 @@
from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse
from fastapi.responses import Response
from fastapi.exceptions import HTTPException
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from slowapi.middleware import SlowAPIMiddleware
from typing import Dict, List
from prometheus_client import Counter, Histogram, start_http_server
from pydantic import BaseModel, ValidationError
from app.services.message import generate_reply, send_reply
import logging
from datetime import datetime
from sentence_transformers import SentenceTransformer

from contextlib import asynccontextmanager
# from app.db.database import create_indexes, init_db
from app.services.webhook_handler import verify_webhook
from app.handlers.message_handler import MessageHandler
from app.handlers.webhook_handler import WebhookHandler
from app.handlers.media_handler import WhatsAppMediaHandler
from app.services.cache import MessageCache
from app.services.chat_manager import ChatManager
from app.api.api_file import file_router
from app.utils.load_env import ACCESS_TOKEN
from app.search.rag_pipeline import RAGSystem

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

logger = logging.getLogger(__name__)

# Initialize handlers at startup
message_handler = None
webhook_handler = None


async def setup_message_handler():
    logger = logging.getLogger(__name__)
    message_cache = MessageCache()
    chat_manager = ChatManager()
    media_handler = WhatsAppMediaHandler()

    return MessageHandler(
        message_cache=message_cache,
        chat_manager=chat_manager,
        media_handler=media_handler,
        logger=logger
    )

async def setup_rag_system():
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2')  # Replace with your model if different
    rag_system = RAGSystem(embedding_model)

    return rag_system

# Initialize FastAPI app
@asynccontextmanager
async def lifespan(app: FastAPI):

    try:
        # await init_db()

        logger.info("Connected to the MongoDB database!")

        rag_system = await setup_rag_system()

        app.state.rag_system = rag_system

        global message_handler, webhook_handler
        message_handler = await setup_message_handler()
        webhook_handler = WebhookHandler(message_handler)
        # collections = app.database.list_collection_names()
        # print(f"Collections in {db_name}: {collections}")
        yield
    except Exception as e:
        logger.error(e)

# Initialize Limiter and Prometheus Metrics
limiter = Limiter(key_func=get_remote_address)
app = FastAPI(lifespan=lifespan)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Add SlowAPI Middleware
app.add_middleware(SlowAPIMiddleware)

# app.include_router(users.router, prefix="/users", tags=["Users"])
app.include_router(file_router, prefix="/file_load", tags=["File Load"])

# Prometheus metrics
webhook_requests = Counter('webhook_requests_total', 'Total webhook requests')
webhook_processing_time = Histogram('webhook_processing_seconds', 'Time spent processing webhook')

# Pseudocode: You may have a database or environment variable to validate API keys
VALID_API_KEYS = {"!@#$%^"}
class WebhookPayload(BaseModel):
    entry: List[Dict]

@app.post("/api/v1/messages")
@limiter.limit("20/minute")
async def process_message(request: Request):
    try:
        # Validate developer's API key
        api_key = request.headers.get("Authorization")
        if not api_key or not api_key.startswith("Bearer "):
            raise HTTPException(status_code=401, detail="Missing or invalid API key")

        api_key_value = api_key.replace("Bearer ", "")
        if api_key_value not in VALID_API_KEYS:
            raise HTTPException(status_code=403, detail="Forbidden")

        payload = await request.json()

        # Extract needed credentials from query params or request body
        # e.g., whatsapp_token, verify_token, llm_api_key, llm_model
        whatsapp_token = request.query_params.get("whatsapp_token")
        whatsapp_url = request.query_params.get("whatsapp_url")
        gemini_api = request.query_params.get("gemini_api")
        llm_model = request.query_params.get("cx_code")

        print(f"payload: {payload}")
        response = await webhook_handler.process_webhook(
            payload=payload,
            whatsapp_token=whatsapp_token,
            whatsapp_url=whatsapp_url,
            gemini_api=gemini_api,
        )

        return JSONResponse(
            content=response.__dict__,
            status_code=status.HTTP_200_OK
        )

    except ValidationError as ve:
        logger.error(f"Validation error: {ve}")
        return JSONResponse(
            content={"status": "error", "detail": ve.errors()},
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
        )
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        return JSONResponse(
            content={"status": "error", "detail": str(e)},
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        )

app.get("/webhook")(verify_webhook)
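Usage sketch (not part of the commit): how a client would call the new /api/v1/messages route, assuming the app runs on localhost:8000; the Bearer value must be one of VALID_API_KEYS, and every credential value below is a placeholder.

import requests

resp = requests.post(
    "http://localhost:8000/api/v1/messages",
    headers={"Authorization": "Bearer !@#$%^"},   # must match an entry in VALID_API_KEYS
    params={
        "whatsapp_token": "<whatsapp-token>",     # placeholder credentials
        "whatsapp_url": "<whatsapp-graph-api-url>",
        "gemini_api": "<gemini-api-key>",
        "cx_code": "<llm-model-or-cx-code>",
    },
    json={"entry": []},                           # webhook-style payload body
)
print(resp.status_code, resp.json())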
app/crud/process_file.py
ADDED
@@ -0,0 +1,152 @@
# app/crud.py
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, CSVLoader, UnstructuredExcelLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sqlalchemy.future import select
from sqlalchemy.ext.asyncio import AsyncSession
from markitdown import MarkItDown
import os
import logging


from typing import List, Optional
# from app.db.models.docs import *
# from app.schemas.schemas import DocumentCreate, DocumentUpdate

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

async def load_file_with_markitdown(file_path:str, llm_client:str=None, model:str=None):

    if llm_client and model:
        markitdown = MarkItDown(llm_client, model)
        documents = markitdown.convert(file_path)
    else:
        markitdown = MarkItDown()
        documents = markitdown.convert(file_path)

    return documents


async def load_pdf_with_langchain(file_path):
    """
    Loads and extracts text from a PDF file using LangChain's PyPDFLoader.

    Parameters:
        file_path (str): Path to the PDF file.

    Returns:
        List[Document]: A list of LangChain Document objects with metadata.
    """

    loader = PyPDFLoader(file_path, extract_images=True)

    documents = loader.load()

    return documents  # Returns a list of Document objects

async def load_file_with_langchain(file_path: str):
    """
    Loads and extracts text from a PDF or DOCX file using LangChain's appropriate loader.

    Parameters:
        file_path (str): Path to the file (PDF or DOCX).

    Returns:
        List[Document]: A list of LangChain Document objects with metadata.
    """
    # Determine the file extension
    _, file_extension = os.path.splitext(file_path)

    # Choose the loader based on file extension
    if file_extension.lower() == '.pdf':
        loader = PyPDFLoader(file_path)
    elif file_extension.lower() == '.docx':
        loader = Docx2txtLoader(file_path)
    elif file_extension.lower() == '.csv':
        loader = CSVLoader(file_path)
    elif file_extension.lower() == '.xlsx':
        loader = UnstructuredExcelLoader(file_path)
    else:
        raise ValueError("Unsupported file format. Please provide a PDF or DOCX file.")

    # Load the documents
    documents = loader.load()

    return documents

async def split_documents(documents, chunk_size=10000, chunk_overlap=1000):
    """
    Splits documents into smaller chunks with overlap.

    Parameters:
        documents (List[Document]): List of LangChain Document objects.
        chunk_size (int): The maximum size of each chunk.
        chunk_overlap (int): The number of characters to overlap between chunks.

    Returns:
        List[Document]: List of chunked Document objects.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    split_docs = text_splitter.split_documents(documents)
    return split_docs

async def process_uploaded_file(
    id, file_path,
    rag_system=None,
    llm_client=None,
    llm_model=None
):

    try:
        # Load the document using LangChain
        documents = await load_file_with_markitdown(file_path, llm_client=llm_client, model=llm_model)
        logger.info(f"Loaded document: {file_path}")

        # Concatenate all pages to get the full document text for context generation
        # whole_document_content = "\n".join([doc.page_content for doc in documents])

    except Exception as e:
        logger.error(f"Failed to load document {file_path}: {e}")
        raise RuntimeError(f"Error loading document: {file_path}") from e

    # # Generate context for each chunk if llm is provided
    # if llm:
    #     for doc in split_docs:
    #         try:
    #             context = await llm.generate_context(doc, whole_document_content=whole_document_content)
    #             # Add context to the beginning of the page content
    #             doc.page_content = f"{context.replace('<|eot_id|>', '')}\n\n{doc.page_content}"
    #             logger.info(f"Context generated and added for chunk {split_docs.index(doc)}")
    #         except Exception as e:
    #             logger.error(f"Failed to generate context for chunk {split_docs.index(doc)}: {e}")
    #             raise RuntimeError(f"Error generating context for chunk {split_docs.index(doc)}") from e

    # Add to RAG system if rag_system is provided and load_only is False
    if rag_system:
        try:
            rag_system.add_document(doc_id = f"{id}_{documents.title}", text = documents.text_content)

            print(f"doc_id: {id}_{documents.title}")
            print(f"content: {documents.text_content}")

            # print(f"New Page Content: {doc.page_content}")
            logger.info(f"Document chunks successfully added to RAG system for file {file_path}")

        except Exception as e:
            logger.error(f"Failed to add document chunks to RAG system for {file_path}: {e}")
            raise RuntimeError(f"Error adding document to RAG system: {file_path}") from e
    else:
        logger.info(f"Loaded document {file_path}, but not added to RAG system")

    return documents
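Usage sketch (not part of the commit): wiring process_uploaded_file to a RAGSystem directly, assuming the same all-MiniLM-L6-v2 model used in app/app.py; the file path is a placeholder. The attributes read at the end (title, text_content) are the ones this module already uses on the MarkItDown result.

import asyncio
import uuid

from sentence_transformers import SentenceTransformer
from app.search.rag_pipeline import RAGSystem
from app.crud.process_file import process_uploaded_file

async def main():
    rag = RAGSystem(SentenceTransformer("all-MiniLM-L6-v2"))
    # "report.pdf" is a hypothetical local file
    doc = await process_uploaded_file(id=str(uuid.uuid4()), file_path="report.pdf", rag_system=rag)
    print(doc.title, len(doc.text_content))

asyncio.run(main())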
app/handlers/media_handler.py
CHANGED
@@ -7,9 +7,9 @@ logger = logging.getLogger(__name__)
 
 class MediaHandler(ABC):
     @abstractmethod
-    async def download(self, media_id: str,
+    async def download(self, media_id: str, whatsapp_token: str, file_path: str) -> str:
         pass
 
 class WhatsAppMediaHandler(MediaHandler):
-    async def download(self, media_id: str,
-        return await download_whatsapp_media(media_id,
+    async def download(self, media_id: str, whatsapp_token: str, file_path: str) -> str:
+        return await download_whatsapp_media(media_id, whatsapp_token, file_path)
app/handlers/message_handler.py
CHANGED
@@ -26,7 +26,7 @@ class MessageHandler:
         self.media_handler = media_handler
         self.logger = logger
 
-    async def handle(self, raw_message: dict,
+    async def handle(self, raw_message: dict, whatsapp_token: str, whatsapp_url:str,gemini_api:str) -> dict:
         try:
             # Parse message
             message = MessageParser.parse(raw_message)
@@ -36,7 +36,7 @@
                 return {"status": "duplicate", "message_id": message.id}
 
             # Download media
-            media_paths = await self._process_media(message,
+            media_paths = await self._process_media(message, whatsapp_token)
 
             self.chat_manager.initialize_chat(message.sender_id)
 
@@ -45,7 +45,9 @@
             result = await process_message_with_llm(
                 message.sender_id,
                 message.content,
-                self.chat_manager.get_chat_history(message.sender_id),
+                self.chat_manager.get_chat_history(message.sender_id),
+                whatsapp_token=whatsapp_token,
+                whatsapp_url=whatsapp_url,
                 **media_paths
             )
 
@@ -60,7 +62,7 @@
         except Exception as e:
             return {"status": "error", "message_id": raw_message.get("id"), "error": str(e)}
 
-    async def _process_media(self, message: Message,
+    async def _process_media(self, message: Message, whatsapp_token: str) -> Dict[str, Optional[str]]:
         media_paths = {
             "image_file_path": None,
             "doc_path": None,
@@ -74,7 +76,7 @@
             self.logger.info(f"Processing {media_type.value}: {content.file_path}")
             file_path = await self.media_handler.download(
                 content.id,
-
+                whatsapp_token,
                 content.file_path
             )
             self.logger.info(f"{media_type.value} file_path: {file_path}")
app/handlers/webhook_handler.py
CHANGED
@@ -18,7 +18,7 @@ class WebhookHandler:
         self.message_handler = message_handler
         self.logger = logging.getLogger(__name__)
 
-    async def process_webhook(self, payload: dict,
+    async def process_webhook(self, payload: dict, whatsapp_token: str, whatsapp_url:str,gemini_api:str) -> WebhookResponse:
         request_id = f"req_{int(time.time()*1000)}"
         results = []
 
@@ -37,7 +37,9 @@
                 self.logger.info(f"Processing message: {message}")
                 response = await self.message_handler.handle(
                     raw_message=message,
-
+                    whatsapp_token=whatsapp_token,
+                    whatsapp_url=whatsapp_url,
+                    gemini_api=gemini_api,
                 )
                 results.append(response)
 
app/main.py
CHANGED
@@ -21,7 +21,7 @@ from app.handlers.webhook_handler import WebhookHandler
 from app.handlers.media_handler import WhatsAppMediaHandler
 from app.services.cache import MessageCache
 from app.services.chat_manager import ChatManager
-
+from app.api.api_prompt import prompt_router
 from app.utils.load_env import ACCESS_TOKEN
 
 # Configure logging
@@ -79,7 +79,7 @@ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
 app.add_middleware(SlowAPIMiddleware)
 
 # app.include_router(users.router, prefix="/users", tags=["Users"])
-
+app.include_router(prompt_router, prefix="/prompts", tags=["Prompts"])
 
 # Prometheus metrics
 webhook_requests = Counter('webhook_requests_total', 'Total webhook requests')
@@ -105,11 +105,19 @@ async def webhook(request: Request):
     # Process the webhook payload here
     # For example:
     # results = process_webhook_entries(validated_payload.entry)
+    # e.g., whatsapp_token, verify_token, llm_api_key, llm_model
+    whatsapp_token = request.query_params.get("whatsapp_token")
+    whatsapp_url = request.query_params.get("whatsapp_url")
+    gemini_api = request.query_params.get("gemini_api")
+    llm_model = request.query_params.get("cx_code")
+
+    print(f"payload: {payload}")
     response = await webhook_handler.process_webhook(
         payload=payload,
-
+        whatsapp_token=whatsapp_token,
+        whatsapp_url=whatsapp_url,
+        gemini_api=gemini_api,
     )
-
     return JSONResponse(
         content=response.__dict__,
         status_code=status.HTTP_200_OK
app/search/bm25_search.py
ADDED
@@ -0,0 +1,159 @@
# bm25_search.py
import asyncio
from rank_bm25 import BM25Okapi
import nltk
import string
from typing import List, Set, Optional
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer


def download_nltk_resources():
    """
    Downloads required NLTK resources synchronously.
    """
    resources = ['punkt', 'stopwords', 'wordnet', 'omw-1.4']
    for resource in resources:
        try:
            nltk.download(resource, quiet=True)
        except Exception as e:
            print(f"Error downloading {resource}: {str(e)}")

class BM25_search:
    # Class variable to track if resources have been downloaded
    nltk_resources_downloaded = False

    def __init__(self, remove_stopwords: bool = True, perform_lemmatization: bool = False):
        """
        Initializes the BM25search.

        Parameters:
        - remove_stopwords (bool): Whether to remove stopwords during preprocessing.
        - perform_lemmatization (bool): Whether to perform lemmatization on tokens.
        """
        # Ensure NLTK resources are downloaded only once
        if not BM25_search.nltk_resources_downloaded:
            download_nltk_resources()
            BM25_search.nltk_resources_downloaded = True  # Mark as downloaded

        self.documents: List[str] = []
        self.doc_ids: List[str] = []
        self.tokenized_docs: List[List[str]] = []
        self.bm25: Optional[BM25Okapi] = None
        self.remove_stopwords = remove_stopwords
        self.perform_lemmatization = perform_lemmatization
        self.stop_words: Set[str] = set(stopwords.words('english')) if remove_stopwords else set()
        self.lemmatizer = WordNetLemmatizer() if perform_lemmatization else None

    def preprocess(self, text: str) -> List[str]:
        """
        Preprocesses the input text by lowercasing, removing punctuation,
        tokenizing, removing stopwords, and optionally lemmatizing.
        """
        text = text.lower().translate(str.maketrans('', '', string.punctuation))
        tokens = nltk.word_tokenize(text)
        if self.remove_stopwords:
            tokens = [token for token in tokens if token not in self.stop_words]
        if self.perform_lemmatization and self.lemmatizer:
            tokens = [self.lemmatizer.lemmatize(token) for token in tokens]
        return tokens

    def add_document(self, doc_id: str, new_doc: str) -> None:
        """
        Adds a new document to the corpus and updates the BM25 index.
        """
        processed_tokens = self.preprocess(new_doc)

        self.documents.append(new_doc)
        self.doc_ids.append(doc_id)
        self.tokenized_docs.append(processed_tokens)
        # Ensure update_bm25 is awaited if required in async context
        self.update_bm25()
        print(f"Added document ID: {doc_id}")

    async def remove_document(self, index: int) -> None:
        """
        Removes a document from the corpus based on its index and updates the BM25 index.
        """
        if 0 <= index < len(self.documents):
            removed_doc_id = self.doc_ids[index]
            del self.documents[index]
            del self.doc_ids[index]
            del self.tokenized_docs[index]
            self.update_bm25()
            print(f"Removed document ID: {removed_doc_id}")
        else:
            print(f"Index {index} is out of bounds.")

    def update_bm25(self) -> None:
        """
        Updates the BM25 index based on the current tokenized documents.
        """
        if self.tokenized_docs:
            self.bm25 = BM25Okapi(self.tokenized_docs)
            print("BM25 index has been initialized.")
        else:
            print("No documents to initialize BM25.")


    def get_scores(self, query: str) -> List[float]:
        """
        Computes BM25 scores for all documents based on the given query.
        """
        processed_query = self.preprocess(query)
        print(f"Tokenized Query: {processed_query}")

        if self.bm25:
            return self.bm25.get_scores(processed_query)
        else:
            print("BM25 is not initialized.")
            return []

    def get_top_n_docs(self, query: str, n: int = 5) -> List[str]:
        """
        Returns the top N documents for a given query.
        """
        processed_query = self.preprocess(query)
        if self.bm25:
            return self.bm25.get_top_n(processed_query, self.documents, n)
        else:
            print("initialized.")
            return []

    def clear_documents(self) -> None:
        """
        Clears all documents from the BM25 index.
        """
        self.documents = []
        self.doc_ids = []
        self.tokenized_docs = []
        self.bm25 = None  # Reset BM25 index
        print("BM25 documents cleared and index reset.")

    def get_document(self, doc_id: str) -> str:
        """
        Retrieves a document by its document ID.

        Parameters:
        - doc_id (str): The ID of the document to retrieve.

        Returns:
        - str: The document text if found, otherwise an empty string.
        """
        try:
            index = self.doc_ids.index(doc_id)
            return self.documents[index]
        except ValueError:
            print(f"Document ID {doc_id} not found.")
            return ""


async def initialize_bm25_search(remove_stopwords: bool = True, perform_lemmatization: bool = False) -> BM25_search:
    """
    Initializes the BM25search with proper NLTK resource downloading.
    """
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, download_nltk_resources)
    return BM25_search(remove_stopwords, perform_lemmatization)
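Usage sketch (not part of the commit): BM25_search used on its own with two toy documents; get_scores returns one score per stored document in insertion order, so it can be zipped with doc_ids.

from app.search.bm25_search import BM25_search

bm25 = BM25_search(remove_stopwords=True)
bm25.add_document("doc-1", "WhatsApp webhook handling and media downloads")
bm25.add_document("doc-2", "Hybrid retrieval with BM25 and FAISS embeddings")

scores = bm25.get_scores("hybrid retrieval")
print(list(zip(bm25.doc_ids, scores)))          # per-document BM25 scores
print(bm25.get_top_n_docs("hybrid retrieval", n=1))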
app/search/faiss_search.py
ADDED
@@ -0,0 +1,76 @@
# faiss_wrapper.py
import faiss
import numpy as np

class FAISS_search:
    def __init__(self, embedding_model):
        self.documents = []
        self.doc_ids = []
        self.embedding_model = embedding_model
        self.dimension = len(embedding_model.encode("embedding"))
        self.index = faiss.IndexIDMap(faiss.IndexFlatL2(self.dimension))

    def add_document(self, doc_id, new_doc):
        self.documents.append(new_doc)
        self.doc_ids.append(doc_id)
        # Encode and add document with its index as ID
        embedding = self.embedding_model.encode([new_doc], convert_to_numpy=True).astype('float32')

        if embedding.size == 0:
            print("No documents to add to FAISS index.")
            return

        idx = len(self.documents) - 1
        id_array = np.array([idx]).astype('int64')
        self.index.add_with_ids(embedding, id_array)

    def remove_document(self, index):
        if 0 <= index < len(self.documents):
            del self.documents[index]
            del self.doc_ids[index]
            # Rebuild the index
            self.build_index()
        else:
            print(f"Index {index} is out of bounds.")

    def build_index(self):
        embeddings = self.embedding_model.encode(self.documents, convert_to_numpy=True).astype('float32')
        idx_array = np.arange(len(self.documents)).astype('int64')
        self.index = faiss.IndexIDMap(faiss.IndexFlatL2(self.dimension))
        self.index.add_with_ids(embeddings, idx_array)

    def search(self, query, k):
        if self.index.ntotal == 0:
            # No documents in the index
            print("FAISS index is empty. No results can be returned.")
            return np.array([]), np.array([])  # Return empty arrays for distances and indices
        query_embedding = self.embedding_model.encode([query], convert_to_numpy=True).astype('float32')
        distances, indices = self.index.search(query_embedding, k)
        return distances[0], indices[0]

    def clear_documents(self) -> None:
        """
        Clears all documents from the FAISS index.
        """
        self.documents = []
        self.doc_ids = []
        # Reset the FAISS index
        self.index = faiss.IndexIDMap(faiss.IndexFlatL2(self.dimension))
        print("FAISS documents cleared and index reset.")

    def get_document(self, doc_id: str) -> str:
        """
        Retrieves a document by its document ID.

        Parameters:
        - doc_id (str): The ID of the document to retrieve.

        Returns:
        - str: The document text if found, otherwise an empty string.
        """
        try:
            index = self.doc_ids.index(doc_id)
            return self.documents[index]
        except ValueError:
            print(f"Document ID {doc_id} not found.")
            return ""
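Usage sketch (not part of the commit): FAISS_search with the same sentence-transformers model the rest of the app uses. search returns L2 distances and integer positions into doc_ids, so a lower distance means a closer match.

from sentence_transformers import SentenceTransformer
from app.search.faiss_search import FAISS_search

index = FAISS_search(SentenceTransformer("all-MiniLM-L6-v2"))
index.add_document("doc-1", "WhatsApp webhook handling and media downloads")
index.add_document("doc-2", "Hybrid retrieval with BM25 and FAISS embeddings")

distances, positions = index.search("vector search", k=2)
for dist, pos in zip(distances, positions):
    print(index.doc_ids[int(pos)], float(dist))   # doc id and its L2 distance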
app/search/hybrid_search.py
ADDED
@@ -0,0 +1,247 @@
1 |
+
import numpy as np
|
2 |
+
import logging, torch
|
3 |
+
from sklearn.preprocessing import MinMaxScaler
|
4 |
+
from sentence_transformers import CrossEncoder
|
5 |
+
# from FlagEmbedding import FlagReranker
|
6 |
+
|
7 |
+
|
8 |
+
class Hybrid_search:
|
9 |
+
def __init__(self, bm25_search, faiss_search, reranker_model_name="BAAI/bge-reranker-v2-gemma", initial_bm25_weight=0.5):
|
10 |
+
self.bm25_search = bm25_search
|
11 |
+
self.faiss_search = faiss_search
|
12 |
+
self.bm25_weight = initial_bm25_weight
|
13 |
+
# self.reranker = FlagReranker(reranker_model_name, use_fp16=True)
|
14 |
+
self.logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
async def advanced_search(self, query, keywords, top_n=5, threshold=0.53, prefixes=None):
|
17 |
+
# Dynamic BM25 weighting
|
18 |
+
self._dynamic_weighting(len(query.split()))
|
19 |
+
keywords = f"{' '.join(keywords)}"
|
20 |
+
self.logger.info(f"Query: {query}")
|
21 |
+
self.logger.info(f"Keywords: {keywords}")
|
22 |
+
|
23 |
+
# Get BM25 scores and doc_ids
|
24 |
+
bm25_scores, bm25_doc_ids = self._get_bm25_results(keywords, top_n = top_n)
|
25 |
+
# self.logger.info(f"BM25 Scores: {bm25_scores}, BM25 doc_ids: {bm25_doc_ids}")
|
26 |
+
# Get FAISS distances, indices, and doc_ids
|
27 |
+
faiss_distances, faiss_indices, faiss_doc_ids = self._get_faiss_results(query)
|
28 |
+
try:
|
29 |
+
faiss_distances, indices, faiss_doc_ids = self._get_faiss_results(query, top_n = top_n)
|
30 |
+
# for dist, idx, doc_id in zip(faiss_distances, indices, faiss_doc_ids):
|
31 |
+
# print(f"Distance: {dist:.4f}, Index: {idx}, Doc ID: {doc_id}")
|
32 |
+
except Exception as e:
|
33 |
+
self.logger.error(f"Search failed: {str(e)}")
|
34 |
+
# Map doc_ids to scores
|
35 |
+
bm25_scores_dict, faiss_scores_dict = self._map_scores_to_doc_ids(
|
36 |
+
bm25_doc_ids, bm25_scores, faiss_doc_ids, faiss_distances
|
37 |
+
)
|
38 |
+
# Create a unified set of doc IDs
|
39 |
+
all_doc_ids = sorted(set(bm25_doc_ids).union(faiss_doc_ids))
|
40 |
+
# print(f"All doc_ids: {all_doc_ids}, BM25 doc_ids: {bm25_doc_ids}, FAISS doc_ids: {faiss_doc_ids}")
|
41 |
+
|
42 |
+
# Filter doc_ids based on prefixes
|
43 |
+
filtered_doc_ids = self._filter_doc_ids_by_prefixes(all_doc_ids, prefixes)
|
44 |
+
# self.logger.info(f"Filtered doc_ids: {filtered_doc_ids}")
|
45 |
+
|
46 |
+
if not filtered_doc_ids:
|
47 |
+
self.logger.info("No documents match the prefixes.")
|
48 |
+
return []
|
49 |
+
|
50 |
+
# Prepare score lists
|
51 |
+
filtered_bm25_scores, filtered_faiss_scores = self._get_filtered_scores(
|
52 |
+
filtered_doc_ids, bm25_scores_dict, faiss_scores_dict
|
53 |
+
)
|
54 |
+
# self.logger.info(f"Filtered BM25 scores: {filtered_bm25_scores}")
|
55 |
+
# self.logger.info(f"Filtered FAISS scores: {filtered_faiss_scores}")
|
56 |
+
|
57 |
+
|
58 |
+
# Normalize scores
|
59 |
+
bm25_scores_normalized, faiss_scores_normalized = self._normalize_scores(
|
60 |
+
filtered_bm25_scores, filtered_faiss_scores
|
61 |
+
)
|
62 |
+
|
63 |
+
# Calculate hybrid scores
|
64 |
+
hybrid_scores = self._calculate_hybrid_scores(bm25_scores_normalized, faiss_scores_normalized)
|
65 |
+
|
66 |
+
# Display hybrid scores
|
67 |
+
for idx, doc_id in enumerate(filtered_doc_ids):
|
68 |
+
print(f"Hybrid Score: {hybrid_scores[idx]:.4f}, Doc ID: {doc_id}")
|
69 |
+
|
70 |
+
# Apply threshold and get top_n results
|
71 |
+
results = self._get_top_n_results(filtered_doc_ids, hybrid_scores, top_n, threshold)
|
72 |
+
# self.logger.info(f"Results before reranking: {results}")
|
73 |
+
|
74 |
+
# If results exist, apply re-ranking
|
75 |
+
# if results:
|
76 |
+
# re_ranked_results = self._rerank_results(query, results)
|
77 |
+
# self.logger.info(f"Results after reranking: {re_ranked_results}")
|
78 |
+
# return re_ranked_results
|
79 |
+
|
80 |
+
return results
|
81 |
+
|
82 |
+
|
83 |
+
def _dynamic_weighting(self, query_length):
|
84 |
+
if query_length <= 5:
|
85 |
+
self.bm25_weight = 0.7
|
86 |
+
else:
|
87 |
+
self.bm25_weight = 0.5
|
88 |
+
self.logger.info(f"Dynamic BM25 weight set to: {self.bm25_weight}")
|
89 |
+
|
90 |
+
def _get_bm25_results(self, keywords, top_n:int = None):
|
91 |
+
# Get BM25 scores
|
92 |
+
bm25_scores = np.array(self.bm25_search.get_scores(keywords))
|
93 |
+
bm25_doc_ids = np.array(self.bm25_search.doc_ids) # Assuming doc_ids is a list of document IDs
|
94 |
+
|
95 |
+
# Log the scores and IDs before filtering
|
96 |
+
# self.logger.info(f"BM25 scores: {bm25_scores}")
|
97 |
+
# self.logger.info(f"BM25 doc_ids: {bm25_doc_ids}")
|
98 |
+
|
99 |
+
# Get the top k indices based on BM25 scores
|
100 |
+
top_k_indices = np.argsort(bm25_scores)[-top_n:][::-1]
|
101 |
+
|
102 |
+
# Retrieve top k scores and corresponding document IDs
|
103 |
+
top_k_scores = bm25_scores[top_k_indices]
|
104 |
+
top_k_doc_ids = bm25_doc_ids[top_k_indices]
|
105 |
+
|
106 |
+
# Return top k scores and document IDs
|
107 |
+
return top_k_scores, top_k_doc_ids
|
108 |
+
|
109 |
+
def _get_faiss_results(self, query, top_n: int = None) -> tuple[np.ndarray, np.ndarray, list[str]]:
|
110 |
+
|
111 |
+
try:
|
112 |
+
# If top_k is not specified, use all documents
|
113 |
+
if top_n is None:
|
114 |
+
top_n = len(self.faiss_search.doc_ids)
|
115 |
+
|
116 |
+
# Use the search's search method which handles the embedding
|
117 |
+
distances, indices = self.faiss_search.search(query, k=top_n)
|
118 |
+
|
119 |
+
if len(distances) == 0 or len(indices) == 0:
|
120 |
+
# Handle case where FAISS returns empty results
|
121 |
+
self.logger.info("FAISS search returned no results.")
|
122 |
+
return np.array([]), np.array([]), []
|
123 |
+
|
124 |
+
# Filter out invalid indices (-1)
|
125 |
+
valid_mask = indices != -1
|
126 |
+
filtered_distances = distances[valid_mask]
|
127 |
+
filtered_indices = indices[valid_mask]
|
128 |
+
|
129 |
+
# Map indices to doc_ids
|
130 |
+
doc_ids = [self.faiss_search.doc_ids[idx] for idx in filtered_indices
|
131 |
+
if 0 <= idx < len(self.faiss_search.doc_ids)]
|
132 |
+
|
133 |
+
# self.logger.info(f"FAISS distances: {filtered_distances}")
|
134 |
+
# self.logger.info(f"FAISS indices: {filtered_indices}")
|
135 |
+
# self.logger.info(f"FAISS doc_ids: {doc_ids}")
|
136 |
+
|
137 |
+
return filtered_distances, filtered_indices, doc_ids
|
138 |
+
|
139 |
+
except Exception as e:
|
140 |
+
self.logger.error(f"Error in FAISS search: {str(e)}")
|
141 |
+
raise
|
142 |
+
|
143 |
+
def _map_scores_to_doc_ids(self, bm25_doc_ids, bm25_scores, faiss_doc_ids, faiss_scores):
|
144 |
+
bm25_scores_dict = dict(zip(bm25_doc_ids, bm25_scores))
|
145 |
+
faiss_scores_dict = dict(zip(faiss_doc_ids, faiss_scores))
|
146 |
+
# self.logger.info(f"BM25 scores dict: {bm25_scores_dict}")
|
147 |
+
# self.logger.info(f"FAISS scores dict: {faiss_scores_dict}")
|
148 |
+
return bm25_scores_dict, faiss_scores_dict
|
149 |
+
|
150 |
+
def _filter_doc_ids_by_prefixes(self, all_doc_ids, prefixes):
|
151 |
+
if prefixes:
|
152 |
+
filtered_doc_ids = [
|
153 |
+
doc_id
                for doc_id in all_doc_ids
                if any(doc_id.startswith(prefix) for prefix in prefixes)
            ]
        else:
            filtered_doc_ids = list(all_doc_ids)
        return filtered_doc_ids

    def _get_filtered_scores(self, filtered_doc_ids, bm25_scores_dict, faiss_scores_dict):
        # Initialize lists to hold scores in the unified doc ID order
        bm25_aligned_scores = []
        faiss_aligned_scores = []

        # Populate aligned score lists, filling missing scores with neutral values
        for doc_id in filtered_doc_ids:
            bm25_aligned_scores.append(bm25_scores_dict.get(doc_id, 0))  # Use 0 if not found in BM25
            faiss_aligned_scores.append(faiss_scores_dict.get(doc_id, max(faiss_scores_dict.values()) + 1))  # Use a high distance if not found in FAISS

        # Invert the FAISS scores (distances) so that larger means more similar
        faiss_aligned_scores = [1 / score if score != 0 else 0 for score in faiss_aligned_scores]

        return bm25_aligned_scores, faiss_aligned_scores

    def _normalize_scores(self, filtered_bm25_scores, filtered_faiss_scores):
        scaler_bm25 = MinMaxScaler()
        bm25_scores_normalized = self._normalize_array(filtered_bm25_scores, scaler_bm25)

        scaler_faiss = MinMaxScaler()
        faiss_scores_normalized = self._normalize_array(filtered_faiss_scores, scaler_faiss)

        # self.logger.info(f"Normalized BM25 scores: {bm25_scores_normalized}")
        # self.logger.info(f"Normalized FAISS scores: {faiss_scores_normalized}")
        return bm25_scores_normalized, faiss_scores_normalized

    def _normalize_array(self, scores, scaler):
        scores_array = np.array(scores)
        if np.ptp(scores_array) > 0:
            normalized_scores = scaler.fit_transform(scores_array.reshape(-1, 1)).flatten()
        else:
            # Handle identical scores with a fallback to uniform 0.5
            normalized_scores = np.full_like(scores_array, 0.5, dtype=float)
        return normalized_scores

    def _calculate_hybrid_scores(self, bm25_scores_normalized, faiss_scores_normalized):
        hybrid_scores = self.bm25_weight * bm25_scores_normalized + (1 - self.bm25_weight) * faiss_scores_normalized
        # self.logger.info(f"Hybrid scores: {hybrid_scores}")
        return hybrid_scores

    def _get_top_n_results(self, filtered_doc_ids, hybrid_scores, top_n, threshold):
        hybrid_scores = np.array(hybrid_scores)
        threshold_indices = np.where(hybrid_scores >= threshold)[0]
        if len(threshold_indices) == 0:
            self.logger.info("No documents meet the threshold.")
            return []

        sorted_indices = threshold_indices[np.argsort(hybrid_scores[threshold_indices])[::-1]]
        top_indices = sorted_indices[:top_n]

        results = [(filtered_doc_ids[idx], hybrid_scores[idx]) for idx in top_indices]
        self.logger.info(f"Top {top_n} results: {results}")
        return results

    def _rerank_results(self, query, results):
        """
        Re-rank the retrieved documents using FlagReranker with normalized scores.

        Parameters:
        - query (str): The search query.
        - results (List[Tuple[str, float]]): A list of (doc_id, score) tuples.

        Returns:
        - List[Tuple[str, float]]: Re-ranked list of (doc_id, score) tuples with normalized scores.
        """
        # Prepare input for the re-ranker
        document_texts = [self.bm25_search.get_document(doc_id) for doc_id, _ in results]
        doc_ids = [doc_id for doc_id, _ in results]

        # Generate pairwise scores using the FlagReranker
        rerank_inputs = [[query, doc] for doc in document_texts]
        with torch.no_grad():
            rerank_scores = self.reranker.compute_score(rerank_inputs, normalize=True)

        # rerank_scores = self.reranker.compute_score(rerank_inputs, normalize=True)

        # Combine doc_ids with normalized re-rank scores and sort by scores
        reranked_results = sorted(
            zip(doc_ids, rerank_scores),
            key=lambda x: x[1],
            reverse=True
        )

        # Log and return results
        # self.logger.info(f"Re-ranked results with normalized scores: {reranked_results}")
        return reranked_results
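The weighted fusion implemented by _get_filtered_scores, _normalize_scores, and _calculate_hybrid_scores above is easiest to check on toy inputs. A minimal standalone sketch, assuming illustrative scores and a bm25_weight of 0.6 (the class uses its own self.bm25_weight):

import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Hypothetical raw scores for three documents (illustrative values only)
bm25_raw = [2.1, 0.0, 5.4]          # BM25 relevance scores (higher is better)
faiss_distances = [0.8, 1.6, 0.4]   # FAISS distances (lower is better)

# Invert FAISS distances so that larger means more similar, as _get_filtered_scores does
faiss_sim = [1 / d if d != 0 else 0 for d in faiss_distances]

def minmax(values):
    arr = np.array(values, dtype=float)
    if np.ptp(arr) > 0:
        return MinMaxScaler().fit_transform(arr.reshape(-1, 1)).flatten()
    # Identical scores collapse to a neutral 0.5, mirroring _normalize_array
    return np.full_like(arr, 0.5)

bm25_weight = 0.6  # assumed weight for this example only
hybrid = bm25_weight * minmax(bm25_raw) + (1 - bm25_weight) * minmax(faiss_sim)
print(hybrid)  # approximately [0.37, 0.00, 1.00] for these inputs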
app/search/rag_pipeline.py
ADDED
@@ -0,0 +1,147 @@
# rag_pipeline.py

import numpy as np

import pickle
import os
import logging
import asyncio

from app.search.bm25_search import BM25_search
from app.search.faiss_search import FAISS_search
from app.search.hybrid_search import Hybrid_search
from app.utils.token_counter import TokenCounter


# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# rag.py
class RAGSystem:
    def __init__(self, embedding_model):
        self.token_counter = TokenCounter()
        self.documents = []
        self.doc_ids = []
        self.results = []
        self.meta_data = []
        self.embedding_model = embedding_model
        self.bm25_wrapper = BM25_search()
        self.faiss_wrapper = FAISS_search(embedding_model)
        self.hybrid_search = Hybrid_search(self.bm25_wrapper, self.faiss_wrapper)

    def add_document(self, doc_id, text, metadata=None):
        self.token_counter.add_document(doc_id, text)
        self.doc_ids.append(doc_id)
        self.documents.append(text)
        self.meta_data.append(metadata)
        self.bm25_wrapper.add_document(doc_id, text)
        self.faiss_wrapper.add_document(doc_id, text)

    def delete_document(self, doc_id):
        try:
            index = self.doc_ids.index(doc_id)
            del self.doc_ids[index]
            del self.documents[index]
            self.bm25_wrapper.remove_document(index)
            self.faiss_wrapper.remove_document(index)
            self.token_counter.remove_document(doc_id)
        except ValueError:
            logging.warning(f"Document ID {doc_id} not found.")

    async def adv_query(self, query_text, keywords, top_k=15, prefixes=None):
        results = await self.hybrid_search.advanced_search(
            query_text,
            keywords=keywords,
            top_n=top_k,
            threshold=0.43,
            prefixes=prefixes
        )
        retrieved_docs = []
        if results:
            seen_docs = set()
            for doc_id, score in results:
                if doc_id not in seen_docs:
                    # Check if the doc_id exists in self.doc_ids
                    if doc_id not in self.doc_ids:
                        logger.error(f"doc_id {doc_id} not found in self.doc_ids")
                    seen_docs.add(doc_id)

                    # Fetch the index of the document
                    try:
                        index = self.doc_ids.index(doc_id)
                    except ValueError as e:
                        logger.error(f"Error finding index for doc_id {doc_id}: {e}")
                        continue

                    # Validate index range
                    if index >= len(self.documents) or index >= len(self.meta_data):
                        logger.error(f"Index {index} out of range for documents or metadata")
                        continue

                    doc = self.documents[index]

                    # meta_data = self.meta_data[index]
                    # Extract the file name and page number
                    # file_name = meta_data['source'].split('/')[-1]  # Extracts 'POJK 31 - 2018.pdf'
                    # page_number = meta_data.get('page', 'unknown')
                    # url = meta_data['source']
                    # file_name = meta_data.get('source', 'unknown_source').split('/')[-1]  # Safe extraction
                    # page_number = meta_data.get('page', 'unknown')  # Default to 'unknown' if 'page' is missing
                    # url = meta_data.get('source', 'unknown_url')  # Default URL fallback

                    # logger.info(f"file_name: {file_name}, page_number: {page_number}, url: {url}")

                    # Format as a single string
                    # content_string = f"'{file_name}', 'page': {page_number}"
                    # doc_name = f"{file_name}"

                    self.results.append(doc)
                    retrieved_docs.append({"text": doc})
            return retrieved_docs
        else:
            return [{"name": "No relevant documents found.", "text": None}]

    def get_total_tokens(self):
        return self.token_counter.get_total_tokens()

    def get_context(self):
        context = "\n".join(self.results)
        return context

    def save_state(self, path):
        # Save doc_ids, documents, and token counter state
        with open(f"{path}_state.pkl", 'wb') as f:
            pickle.dump({
                "doc_ids": self.doc_ids,
                "documents": self.documents,
                "meta_data": self.meta_data,
                "token_counts": self.token_counter.doc_tokens
            }, f)

    def load_state(self, path):
        if os.path.exists(f"{path}_state.pkl"):
            with open(f"{path}_state.pkl", 'rb') as f:
                state_data = pickle.load(f)
            self.doc_ids = state_data["doc_ids"]
            self.documents = state_data["documents"]
            self.meta_data = state_data["meta_data"]
            self.token_counter.doc_tokens = state_data["token_counts"]

            # Clear and rebuild BM25 and FAISS
            self.bm25_wrapper.clear_documents()
            self.faiss_wrapper.clear_documents()
            for doc_id, document in zip(self.doc_ids, self.documents):
                self.bm25_wrapper.add_document(doc_id, document)
                self.faiss_wrapper.add_document(doc_id, document)

            self.token_counter.total_tokens = sum(self.token_counter.doc_tokens.values())
            logging.info("System state loaded successfully with documents and indices rebuilt.")
        else:
            logging.info("No previous state found, initializing fresh state.")
            self.doc_ids = []
            self.documents = []
            self.meta_data = []  # Reset meta_data
            self.token_counter = TokenCounter()
            self.bm25_wrapper = BM25_search()
            self.faiss_wrapper = FAISS_search(self.embedding_model)
            self.hybrid_search = Hybrid_search(self.bm25_wrapper, self.faiss_wrapper)
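A brief usage sketch of the RAGSystem class added above; the document IDs and texts are placeholders, and the embedding model is left as whatever object FAISS_search expects:

import asyncio
from app.search.rag_pipeline import RAGSystem

async def demo(embedder):
    rag = RAGSystem(embedder)

    rag.add_document("faq_1", "How to register a small business in Surabaya ...")
    rag.add_document("perda_1", "City regulation text on waste management ...")

    # Restrict retrieval to documents whose IDs start with "faq"
    docs = await rag.adv_query(
        "business registration",
        keywords=["register", "business"],
        top_k=5,
        prefixes=["faq"],
    )
    print(docs)                    # e.g. [{"text": "How to register ..."}]
    print(rag.get_total_tokens())  # token total tracked by TokenCounter

    rag.save_state("rag_index")    # writes rag_index_state.pkl
    rag.load_state("rag_index")    # reloads docs and rebuilds BM25/FAISS

# asyncio.run(demo(my_embedding_model))  # supply whatever model FAISS_search accepts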
app/services/message.py
CHANGED
@@ -6,7 +6,9 @@ from datetime import datetime
 import logging
 import asyncio
 from openai import AsyncOpenAI
+import json
 import google.generativeai as genai
+
 import PIL.Image
 from typing import List, Dict, Any, Optional
 
@@ -17,6 +19,29 @@ from app.services.search_engine import google_search
 # Load environment variables
 load_dotenv()
 
+# Define function specifications for Gemini
+function_declarations = [
+    {
+        "name": "google_search",
+        "description": "Perform a Google search and retrieve search results",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "query": {
+                    "type": "string",
+                    "description": "The search query to perform"
+                },
+                "num_results": {
+                    "type": "string",
+                    "description": "Number of search results to retrieve (1-10)",
+                    "default": "3"
+                }
+            },
+            "required": ["query"]
+        }
+    }
+]
+
 genai.configure(api_key=GEMNI_API)
 client = AsyncOpenAI(api_key = OPENAI_API)
 # Configure logging
@@ -31,9 +56,9 @@ if not WHATSAPP_API_URL or not ACCESS_TOKEN:
     logger.warning("Environment variables for WHATSAPP_API_URL or ACCESS_TOKEN are not set!")
 
 # Helper function to send a reply
-async def send_reply(to: str, body: str) -> Dict[str, Any]:
+async def send_reply(to: str, body: str, whatsapp_token: str, whatsapp_url:str) -> Dict[str, Any]:
     headers = {
-        "Authorization": f"Bearer {
+        "Authorization": f"Bearer {whatsapp_token}",
         "Content-Type": "application/json"
     }
     data = {
@@ -46,7 +71,7 @@ async def send_reply(to: str, body: str) -> Dict[str, Any]:
     }
 
     async with httpx.AsyncClient() as client:
-        response = await client.post(
+        response = await client.post(whatsapp_url, json=data, headers=headers)
 
     if response.status_code != 200:
         error_detail = response.json()
@@ -74,7 +99,9 @@ async def generate_reply(sender: str, content: str, timestamp: int) -> str:
 async def process_message_with_llm(
     sender_id: str,
     content: str,
-    history: List[Dict[str, str]],
+    history: List[Dict[str, str]],
+    whatsapp_token: str,
+    whatsapp_url:str,
     image_file_path: Optional[str] = None,
     doc_path: Optional[str] = None,
     video_file_path: Optional[str] = None,
@@ -92,7 +119,7 @@ async def process_message_with_llm(
         )
         logger.info(f"Generated reply: {generated_reply}")
 
-        response = await send_reply(sender_id, generated_reply)
+        response = await send_reply(sender_id, generated_reply, whatsapp_token, whatsapp_url)
         # return generated_reply
         return generated_reply
     except Exception as e:
@@ -140,14 +167,55 @@ async def generate_response_from_gemini(
             pass  # Placeholder for video processing logic
 
         # Send the user's message
-        response = await chat.send_message_async(content
+        response = await chat.send_message_async(content)
+        # response = await handle_function_call(response)
        return response.text
 
     except Exception as e:
         logger.error("Error in generate_response_from_gemini:", exc_info=True)
         return "Sorry, I couldn't generate a response at this time."
 
-
+async def handle_function_call(chat):
+    """
+    Handle function calls from the Gemini API.
+
+    Args:
+        chat (ChatSession): The current chat session.
+
+    Returns:
+        The response after resolving function calls.
+    """
+    # Continue the conversation and handle any function calls
+    while True:
+        response = chat.send_message_async(chat.history[-1])
+
+        # Check if there are any function calls to handle
+        if response.candidates[0].content.parts[0].function_call:
+            function_call = response.candidates[0].content.parts[0].function_call
+            function_name = function_call.name
+            function_args = json.loads(function_call.args)
+
+            # Dispatch to the appropriate function
+            if function_name == "google_search":
+                # Handle async function call
+                result = await google_search(
+                    query=function_args['query'],
+                    num_results=function_args.get('num_results', '3')
+                )
+
+                # Send the function result back to continue the conversation
+                response = chat.send_message_async(
+                    part={
+                        "function_response": {
+                            "name": function_name,
+                            "response": result
+                        }
+                    }
+                )
+        else:
+            # No more function calls, return the final response
+            return response
 
 # Process message with retry logic
 # async def process_message_with_retry(
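For context, a minimal sketch of calling the reworked send_reply signature after this change; the token, URL, and phone number are placeholders:

import asyncio
from app.services.message import send_reply

async def demo():
    # Placeholder credentials; in the app these are passed in per request
    whatsapp_token = "YOUR_WHATSAPP_ACCESS_TOKEN"
    whatsapp_url = "https://graph.facebook.com/v21.0/<phone-number-id>/messages"

    await send_reply("6281234567890", "Hello from the bot!", whatsapp_token, whatsapp_url)

asyncio.run(demo())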
app/services/search_engine.py
CHANGED
@@ -60,7 +60,8 @@ async def google_search(query: str, num_results: str = "3") -> Optional[List[Dic
         if response.status_code == 200:
             results = response.json()
             items = results.get("items", [])
-
+
+            return [{"title": item["title"], "link": item["link"], "snippet": item["snippet"]} for item in items]
 
         else:
             logger.error(f"Google Search API error: {response.status_code} - {response.text}")
@@ -70,3 +71,18 @@
         logger.error("A network error occurred while performing the Google search.")
         logger.error(f"Error details: {e}")
         return None
+
+def set_light_values(brightness: str, color_temp: str) -> Dict[str, str]:
+    """Set the brightness and color temperature of a room light. (mock API).
+
+    Args:
+        brightness: Light level from 0 to 100. Zero is off and 100 is full brightness
+        color_temp: Color temperature of the light fixture, which can be `daylight`, `cool` or `warm`.
+
+    Returns:
+        A dictionary containing the set brightness and color temperature.
+    """
+    return {
+        "brightness": brightness,
+        "colorTemperature": color_temp
+    }
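A short sketch of consuming what google_search now returns, together with the mock set_light_values tool; it assumes CX_CODE and CUSTOM_SEARCH_API_KEY are already configured in the environment:

import asyncio
from app.services.search_engine import google_search, set_light_values

async def demo():
    results = await google_search("Surabaya public services", num_results="3")
    if results:
        for item in results:
            print(item["title"], "->", item["link"])
            print("  ", item["snippet"])

    # The mock light API simply echoes its arguments back
    print(set_light_values("75", "warm"))

asyncio.run(demo())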
app/utils/load_env.py
CHANGED
@@ -20,8 +20,8 @@ CX_CODE = os.getenv("CX_CODE")
 CUSTOM_SEARCH_API_KEY = os.getenv("CUSTOM_SEARCH_API_KEY")
 
 # Debugging: Print the retrieved ACCESS_TOKEN (for development only)
-if ENV == "development":
-    print(f"ACCESS_TOKEN loaded: {ACCESS_TOKEN}")
+# if ENV == "development":
+#     print(f"ACCESS_TOKEN loaded: {ACCESS_TOKEN}")
 
 
 
app/utils/system_prompt.py
CHANGED
@@ -22,4 +22,10 @@ Example Interactions:
 If a user asks, “Are there any issues with the city government's policies?” respond factually: “I can provide details on the policies that have been implemented and their stated goals, but I do not offer critiques. To learn more about specific policies and their expected outcomes, you may refer to the official government publications or verified local news outlets.”
 
 By following these guidelines, you will serve as a reliable, respectful, and informative resource for users looking to understand the latest happenings in Surabaya without engaging in criticism of the government.
 """
+
+agentic_prompt = """ You are a helpful assistant and have capabilities to search the web.
+When you the links are given, you should summarize the content of the link and give a short summary.
+You should also include the source of the link in the summary.
+
+"""
app/utils/token_counter.py
ADDED
@@ -0,0 +1,29 @@
# token_counter.py

import os
import tiktoken

# Choose the encoding based on your model, e.g., 'cl100k_base' for OpenAI models
encoding = tiktoken.get_encoding("cl100k_base")

def count_tokens(text):
    tokens = encoding.encode(text)
    return len(tokens)

class TokenCounter:
    def __init__(self):
        self.total_tokens = 0
        self.doc_tokens = {}

    def add_document(self, doc_id, text):
        num_tokens = count_tokens(text)
        self.doc_tokens[doc_id] = num_tokens
        self.total_tokens += num_tokens

    def remove_document(self, doc_id):
        if doc_id in self.doc_tokens:
            self.total_tokens -= self.doc_tokens[doc_id]
            del self.doc_tokens[doc_id]

    def get_total_tokens(self):
        return self.total_tokens
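A quick illustration of the TokenCounter added above; the sample strings are arbitrary:

from app.utils.token_counter import TokenCounter

counter = TokenCounter()
counter.add_document("doc-1", "Surabaya traffic update for today.")
counter.add_document("doc-2", "Public park opening hours and locations.")
print(counter.get_total_tokens())   # total cl100k_base tokens across both documents

counter.remove_document("doc-1")
print(counter.get_total_tokens())   # total drops by doc-1's token count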
app/utils/tool_call_extractor.py
ADDED
@@ -0,0 +1,149 @@
import re
import json
from typing import List, Dict, Any, Optional

class ToolCallExtractor:
    def __init__(self):
        # Existing regex patterns (retain if needed for other formats)
        self.complete_pattern = re.compile(r'<\|python_tag\|>(.*?)<\|eom_id\|>', re.DOTALL)
        self.partial_pattern = re.compile(r'(.*?)<\|eom_id\|>', re.DOTALL)

    def _extract_function_args(self, args) -> Dict[str, Any]:
        """
        Flatten the nested function args structure for Google AI protobuf types.
        """
        flattened_args = {}

        try:
            # Explicitly check for fields
            if hasattr(args, 'fields'):
                # Iterate through fields using to_dict() to convert protobuf to dict
                for field in args.fields:
                    key = field.key
                    value = field.value

                    # Additional debugging
                    print(f"Field key: {key}")
                    print(f"Field value type: {type(value)}")
                    print(f"Field value: {value}")

                    # Extract string value
                    if hasattr(value, 'string_value'):
                        flattened_args[key] = value.string_value
                        print(f"Extracted string value: {value.string_value}")
                    elif hasattr(value, 'number_value'):
                        flattened_args[key] = value.number_value
                    elif hasattr(value, 'bool_value') and value.bool_value is not None:
                        flattened_args[key] = value.bool_value

            # Added additional debug information
            print(f"Final flattened args: {flattened_args}")

        except Exception as e:
            print(f"Error extracting function args: {e}")

        return flattened_args

    def extract_tool_calls(self, input_string: str) -> List[Dict[str, Any]]:
        """
        Extract tool calls from input string, handling various inconsistent formats.

        Args:
            input_string (str): The input string containing tool calls.

        Returns:
            list: A list of dictionaries representing the parsed tool calls.
        """
        tool_calls = []

        # Existing tag-based extraction (retain if needed)
        complete_matches = self.complete_pattern.findall(input_string)
        if complete_matches:
            for match in complete_matches:
                tool_calls.extend(self._extract_json_objects(match))
            return tool_calls

        partial_matches = self.partial_pattern.findall(input_string)
        if partial_matches:
            for match in partial_matches:
                tool_calls.extend(self._extract_json_objects(match))
            return tool_calls

        # Fallback: Attempt to parse the entire string
        tool_calls.extend(self._extract_json_objects(input_string))

        return tool_calls

    def _extract_json_objects(self, text: str) -> List[Dict[str, Any]]:
        """
        Extract and parse multiple JSON objects from a string.
        """
        json_objects = []
        potential_jsons = text.split(';')

        for json_str in potential_jsons:
            parsed_obj = self._clean_and_parse_json(json_str)
            if parsed_obj:
                json_objects.append(parsed_obj)

        return json_objects

    def _clean_and_parse_json(self, json_str: str) -> Optional[Dict[str, Any]]:
        """
        Clean and parse a JSON string, handling common formatting issues.
        """
        try:
            json_str = json_str.strip()
            if json_str.startswith('{') or json_str.startswith('['):
                return json.loads(json_str)
            return None
        except json.JSONDecodeError:
            return None

    def validate_tool_call(self, tool_call: Dict[str, Any]) -> bool:
        """
        Validate if a tool call has the required fields.
        """
        return (
            isinstance(tool_call, dict) and
            'name' in tool_call and
            isinstance(tool_call['name'], str)
        )

    def extract_function_call(self, response_parts: List[Any]) -> Dict[str, Any]:
        """
        Extract function call details from the response parts.

        Args:
            response_parts (list): The list of response parts from the chat model.

        Returns:
            dict: A dictionary containing the function name and flattened arguments.
        """
        for part in response_parts:
            # Debug print
            print(f"Examining part: {part}")
            print(f"Part type: {type(part)}")

            # Check for function_call attribute
            if hasattr(part, 'function_call') and part.function_call:
                function_call = part.function_call

                # Debug print
                print(f"Function call: {function_call}")
                print(f"Function call type: {type(function_call)}")
                print(f"Function args: {function_call.args}")

                # Extract function name
                function_name = getattr(function_call, 'name', None)
                if not function_name:
                    continue  # Skip if function name is missing

                # Extract function arguments
                function_args = self._extract_function_args(function_call.args)

                return {
                    "name": function_name,
                    "args": function_args
                }
        return {}
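A small sketch of running ToolCallExtractor against a tagged tool-call string; the payload is made up for illustration:

from app.utils.tool_call_extractor import ToolCallExtractor

extractor = ToolCallExtractor()

# Hypothetical model output using the <|python_tag|> ... <|eom_id|> convention
raw = '<|python_tag|>{"name": "google_search", "parameters": {"query": "cheap flights"}}<|eom_id|>'

calls = extractor.extract_tool_calls(raw)
valid_calls = [c for c in calls if extractor.validate_tool_call(c)]
print(valid_calls)  # [{'name': 'google_search', 'parameters': {'query': 'cheap flights'}}]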
document_logs_2024-12-20.txt
ADDED
@@ -0,0 +1 @@
2024-12-20 12:49:01,713 - INFO - ID: ea205193-4582-44bc-ab71-80176aac7aef, Snippet: [SSW](https://sswalfa.surabaya.go.id/home) [
testcode.py
CHANGED

# Set up logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

def tool_config_from_mode(mode: str, fns: Iterable[str] = ()):
    """
    Create a tool config with the specified function calling mode.
    """
    return content_types.to_tool_config(
        {"function_calling_config": {"mode": mode, "allowed_function_names": fns}}
    )

def transform_result_to_response(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Transform a list of result objects into a structured response dictionary.
    """
    response = {}
    for res in results:
        if res.get("status") == "success":
            function_name = res.get("function")
            function_result = res.get("result")
            response[function_name] = function_result
        else:
            # Handle individual failures if necessary
            response[res.get("function", "unknown_function")] = {
                "error": "Function execution failed."
            }
    return response

async def process_tool_calls(input_string: str) -> List[Dict[str, Any]]:
    """
    Processes all tool calls extracted from the input string and executes them.
    """
    tool_calls = ToolCallParser.extract_tool_calls(input_string)
    logger.info(f"Extracted tool_calls: {tool_calls}")
    results = []
    for tool_call in tool_calls:
        result = await FunctionExecutor.call_function(tool_call)
        results.append(result)
    return results

async def main():
    # Define available functions and tool configuration
    available_functions = ["google_search", "set_light_values"]
    config = tool_config_from_mode("any", fns=available_functions)

    # Define chat history
    history = [{"role": "user", "parts": "This is the chat history so far"}]

    # Configure the Gemini API
    genai.configure(api_key=GEMNI_API)
    model = genai.GenerativeModel(
        "gemini-1.5-pro-002",
        system_instruction=agentic_prompt,
        tools=[google_search, set_light_values]
    )

    # Start chat with history
    chat = model.start_chat(history=history)

    # Send the user's message and await the response
    try:
        response = chat.send_message(
            "find the cheapest flight price from Medan to Jakarta on 1st January 2025",
            tool_config=config
        )
    except Exception as e:
        logger.error(f"Error sending message: {e}")
        return

    # Ensure that response.parts exists and is iterable
    if not hasattr(response, 'parts') or not isinstance(response.parts, Iterable):
        logger.error("Invalid response format: 'parts' attribute is missing or not iterable.")
        return

    # Convert response parts to a single input string
    input_string = "\n".join(str(part) for part in response.parts)
    logger.info(f"Input string for tool processing: {input_string}")

    # Process tool calls
    try:
        results = await process_tool_calls(input_string)
    except Exception as e:
        logger.error(f"Error processing tool calls: {e}")
        return

    # Log and print the results
    logger.info("Results from tool calls:")
    for result in results:
        logger.info(json.dumps(result, indent=4))
        print(json.dumps(result, indent=4))

    # Transform the results into the desired response format
    responses = transform_result_to_response(results)

    # Build the response parts for the chat
    try:
        response_parts = [
            genai.protos.Part(
                function_response=genai.protos.FunctionResponse(
                    name=fn,
                    response={"result": val}
                )
            )
            for fn, val in responses.items()
        ]
    except Exception as e:
        logger.error(f"Error building response parts: {e}")
        return

    # Send the function responses back to the chat
    try:
        final_response = chat.send_message(response_parts)
        print(final_response.text)
    except Exception as e:
        logger.error(f"Error sending final response: {e}")

if __name__ == "__main__":
    asyncio.run(main())