ariansyahdedy committed
Commit 8d2f9d4
1 Parent(s): 44e73ac

Add prompt edit and api key config

.gitignore CHANGED
@@ -2,5 +2,8 @@
 __pycache__
 .env
 user_media/
+ toolkits/
+ test*.py
+


app/api/api_file.py ADDED
@@ -0,0 +1,261 @@
1
+ from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, Request, Query, status
2
+ from fastapi.responses import StreamingResponse
3
+ import os
4
+ import logging
5
+ import uuid
6
+ from datetime import datetime
7
+
8
+ from pydantic import BaseModel, Field
9
+ from typing import Optional, List, Any
10
+ from urllib.parse import urlparse
11
+ import shutil
12
+ # from app.wrapper.llm_wrapper import *
13
+ from app.crud.process_file import load_file_with_markitdown, process_uploaded_file
14
+
15
+ # Configure logging
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def is_url(path: str) -> bool:
21
+ """
22
+ Determines if the given path is a URL.
23
+
24
+ Args:
25
+ path (str): The path or URL to check.
26
+
27
+ Returns:
28
+ bool: True if it's a URL, False otherwise.
29
+ """
30
+ try:
31
+ result = urlparse(path)
32
+ return all([result.scheme, result.netloc])
33
+ except Exception:
34
+ return False
35
+
36
+
37
+ file_router = APIRouter()
38
+
39
+ # Configure logging to file with date-based filenames
40
+ log_filename = f"document_logs_{datetime.now().strftime('%Y-%m-%d')}.txt"
41
+ file_handler = logging.FileHandler(log_filename)
42
+ formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
43
+ file_handler.setFormatter(formatter)
44
+
45
+ # Create a logger for document processing
46
+ doc_logger = logging.getLogger('document_logger')
47
+ doc_logger.setLevel(logging.INFO)
48
+ doc_logger.addHandler(file_handler)
49
+
50
+ # Also configure the general logger if not already configured
51
+ logging.basicConfig(level=logging.INFO)
52
+ logger = logging.getLogger(__name__)
53
+
54
+ from app.search.rag_pipeline import RAGSystem
55
+ from sentence_transformers import SentenceTransformer
56
+
57
+
58
+ @file_router.post("/load_file_with_markdown/")
59
+ async def load_file_with_markdown(request: Request, filepaths: List[str]):
60
+ try:
61
+ # Ensure RAG system is initialized
62
+ try:
63
+ rag_system = request.app.state.rag_system
64
+ if rag_system is None:
65
+ raise AttributeError("RAG system is not initialized in app state")
66
+ except AttributeError:
67
+ logger.error("RAG system is not initialized in app state")
68
+ raise HTTPException(status_code=500, detail="RAG system not initialized in app state")
69
+
70
+
71
+ processed_files = []
72
+ pages = []
73
+
74
+ # Process each file path or URL
75
+ for path in filepaths:
76
+ if is_url(path):
77
+ logger.info(f"Processing URL: {path}")
78
+ try:
79
+ # Generate a unique UUID for the document
80
+ doc_id = str(uuid.uuid4())
81
+
82
+ # Process the URL
83
+ document = await process_uploaded_file(id=doc_id, file_path=path, rag_system=rag_system)
84
+
85
+ # Append the document details to pages
86
+ pages.append({
87
+ "metadata": {"title": document.title},
88
+ "page_content": document.text_content,
89
+ })
90
+
91
+ logger.info(f"Successfully processed URL: {path} with ID: {doc_id}")
92
+
93
+ # Log the ID and a 100-character snippet of the document
94
+ snippet = document.text_content[:100].replace('\n', ' ').replace('\r', ' ')
95
+ # Ensure 'doc_logger' is defined; if not, use 'logger' or define 'doc_logger'
96
+ doc_logger.info(f"ID: {doc_id}_{document.title}, Snippet: {snippet}")
97
+
98
+
99
+ except Exception as e:
100
+ logger.error(f"Error processing URL {path}: {str(e)}")
101
+ processed_files.append({"path": path, "status": "error", "message": str(e)})
102
+
103
+ else:
104
+ logger.info(f"Processing local file: {path}")
105
+ if os.path.exists(path):
106
+ try:
107
+ # Generate a unique UUID for the document
108
+ doc_id = str(uuid.uuid4())
109
+
110
+ # Process the local file
111
+ document = await process_uploaded_file(id=doc_id, file_path=path, rag_system=rag_system)
112
+
113
+ # Append the document details to pages
114
+ pages.append({
115
+ "metadata": {"title": document.title},
116
+ "page_content": document.text_content,
117
+ })
118
+
119
+ logger.info(f"Successfully processed file: {path} with ID: {doc_id}")
120
+
121
+ # Log the ID and a 100-character snippet of the document
122
+ snippet = document.text_content[:100].replace('\n', ' ').replace('\r', ' ')
123
+ # Ensure 'doc_logger' is defined; if not, use 'logger' or define 'doc_logger'
124
+ logger.info(f"ID: {doc_id}, Snippet: {snippet}")
125
+
126
+ except Exception as e:
127
+ logger.error(f"Error processing file {path}: {str(e)}")
128
+ processed_files.append({"path": path, "status": "error", "message": str(e)})
129
+ else:
130
+ logger.error(f"File path does not exist: {path}")
131
+ processed_files.append({"path": path, "status": "not found"})
132
+
133
+ # Get total tokens from RAG system
134
+ total_tokens = rag_system.get_total_tokens() if hasattr(rag_system, "get_total_tokens") else 0
135
+
136
+ return {
137
+ "message": "File processing completed",
138
+ "total_tokens": total_tokens,
139
+ "document_count": len(filepaths),
140
+ "pages": pages,
141
+ "errors": processed_files, # Include details about files that couldn't be processed
142
+ }
143
+
144
+ except Exception as e:
145
+ logger.exception("Unexpected error during file processing")
146
+ raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
147
+
148
+ async def load_file_with_markdown_function(filepaths: List[str],
149
+ rag_system: Any):
150
+ try:
151
+ # Ensure RAG system is initialized
152
+ try:
153
+ rag_system = rag_system
154
+ except AttributeError:
155
+ logger.error("RAG system is not initialized in app state")
156
+ raise HTTPException(status_code=500, detail="RAG system not initialized in app state")
157
+
158
+
159
+ processed_files = []
160
+ pages = []
161
+
162
+ # Process each file path or URL
163
+ for path in filepaths:
164
+ if is_url(path):
165
+ logger.info(f"Processing URL: {path}")
166
+ try:
167
+ # Generate a unique UUID for the document
168
+ doc_id = str(uuid.uuid4())
169
+
170
+ # Process the URL
171
+ document = await process_uploaded_file(id=doc_id, file_path=path, rag_system=rag_system)
172
+
173
+ # Append the document details to pages
174
+ pages.append({
175
+ "metadata": {"title": document.title},
176
+ "page_content": document.text_content,
177
+ })
178
+
179
+ logger.info(f"Successfully processed URL: {path} with ID: {doc_id}")
180
+
181
+ # Log the ID and a 100-character snippet of the document
182
+ snippet = document.text_content[:100].replace('\n', ' ').replace('\r', ' ')
183
+ # Ensure 'doc_logger' is defined; if not, use 'logger' or define 'doc_logger'
184
+ doc_logger(f"ID: {doc_id}, Snippet: {snippet}")
185
+ logger.info(f"ID: {doc_id}, Snippet: {snippet}")
186
+
187
+ except Exception as e:
188
+ logger.error(f"Error processing URL {path}: {str(e)}")
189
+ processed_files.append({"path": path, "status": "error", "message": str(e)})
190
+
191
+ else:
192
+ logger.info(f"Processing local file: {path}")
193
+ if os.path.exists(path):
194
+ try:
195
+ # Generate a unique UUID for the document
196
+ doc_id = str(uuid.uuid4())
197
+
198
+ # Process the local file
199
+ document = await process_uploaded_file(id=doc_id, file_path=path, rag_system=rag_system)
200
+
201
+ # Append the document details to pages
202
+ pages.append({
203
+ "metadata": {"title": document.title},
204
+ "page_content": document.text_content,
205
+ })
206
+
207
+ logger.info(f"Successfully processed file: {path} with ID: {doc_id}")
208
+
209
+ # Log the ID and a 100-character snippet of the document
210
+ snippet = document.text_content[:100].replace('\n', ' ').replace('\r', ' ')
211
+ # Ensure 'doc_logger' is defined; if not, use 'logger' or define 'doc_logger'
212
+ logger.info(f"ID: {doc_id}, Snippet: {snippet}")
213
+
214
+ except Exception as e:
215
+ logger.error(f"Error processing file {path}: {str(e)}")
216
+ processed_files.append({"path": path, "status": "error", "message": str(e)})
217
+ else:
218
+ logger.error(f"File path does not exist: {path}")
219
+ processed_files.append({"path": path, "status": "not found"})
220
+
221
+ # Get total tokens from RAG system
222
+ total_tokens = rag_system.get_total_tokens() if hasattr(rag_system, "get_total_tokens") else 0
223
+
224
+ return {
225
+ "message": "File processing completed",
226
+ "total_tokens": total_tokens,
227
+ "document_count": len(filepaths),
228
+ "pages": pages,
229
+ "errors": processed_files, # Include details about files that couldn't be processed
230
+ }
231
+
232
+ except Exception as e:
233
+ logger.exception("Unexpected error during file processing")
234
+ raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
235
+
236
+ @file_router.get("/document_exists/{doc_id}", status_code=status.HTTP_200_OK)
237
+ async def document_exists(request: Request, doc_id: str):
238
+ try:
239
+ rag_system = request.app.state.rag_system
240
+ except AttributeError:
241
+ logger.error("RAG system is not initialized in app state")
242
+ raise HTTPException(status_code=500, detail="RAG system not initialized in app state")
243
+
244
+ exists = doc_id in rag_system.doc_ids
245
+ return {"document_id": doc_id, "exists": exists}
246
+
247
+ @file_router.delete("/delete_document/{doc_id}", status_code=status.HTTP_200_OK)
248
+ async def delete_document(request: Request, doc_id: str):
249
+ try:
250
+ rag_system = request.app.state.rag_system
251
+ except AttributeError:
252
+ logger.error("RAG system is not initialized in app state")
253
+ raise HTTPException(status_code=500, detail="RAG system not initialized in app state")
254
+
255
+ try:
256
+ rag_system.delete_document(doc_id)
257
+ logger.info(f"Deleted document with ID: {doc_id}")
258
+ return {"message": f"Document with ID {doc_id} has been deleted."}
259
+ except Exception as e:
260
+ logger.error(f"Error deleting document with ID {doc_id}: {str(e)}")
261
+ raise HTTPException(status_code=500, detail=f"Failed to delete document: {str(e)}")
app/api/api_prompt.py ADDED
@@ -0,0 +1,32 @@
1
+ from fastapi import FastAPI, HTTPException, APIRouter
2
+ from pydantic import BaseModel
3
+ from app.utils.system_prompt import system_prompt, agentic_prompt
4
+
5
+ prompt_router = APIRouter()
6
+
7
+ # Define a model for the prompts
8
+ class Prompt(BaseModel):
9
+ system_prompt: str = None
10
+ agentic_prompt: str = None
11
+
12
+ # API endpoint to get the current prompts
13
+ @prompt_router.get("/prompts")
14
+ def get_prompts():
15
+ return {
16
+ "system_prompt": system_prompt,
17
+ "agentic_prompt": agentic_prompt,
18
+ }
19
+
20
+ # API endpoint to update the prompts
21
+ @prompt_router.put("/prompts")
22
+ def update_prompts(prompts: Prompt):
23
+ global system_prompt, agentic_prompt
24
+ if prompts.system_prompt is not None:
25
+ system_prompt = prompts.system_prompt
26
+ if prompts.agentic_prompt is not None:
27
+ agentic_prompt = prompts.agentic_prompt
28
+ return {
29
+ "message": "Prompts updated successfully",
30
+ "system_prompt": system_prompt,
31
+ "agentic_prompt": agentic_prompt,
32
+ }
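
A short usage sketch of the new prompt endpoints (illustrative, not part of the commit). It assumes the app from app/main.py, which mounts prompt_router with prefix="/prompts", is running locally on port 8000, so both routes resolve under /prompts/prompts; host and port are assumptions.

import httpx

BASE = "http://localhost:8000/prompts"

# Read the current prompts
current = httpx.get(f"{BASE}/prompts").json()
print(current["system_prompt"], current["agentic_prompt"])

# Update only the system prompt; agentic_prompt is left unchanged
resp = httpx.put(f"{BASE}/prompts", json={"system_prompt": "You are a concise assistant."})
print(resp.json()["message"])

Note that the PUT handler rebinds module-level names with global, so the change is visible through this router's own GET endpoint, but modules that imported system_prompt directly keep their original binding.
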
app/app.py ADDED
@@ -0,0 +1,153 @@
1
+ from fastapi import FastAPI, Request, status
2
+ from fastapi.responses import JSONResponse
3
+ from fastapi.responses import Response
4
+ from fastapi.exceptions import HTTPException
5
+ from slowapi import Limiter, _rate_limit_exceeded_handler
6
+ from slowapi.errors import RateLimitExceeded
7
+ from slowapi.util import get_remote_address
8
+ from slowapi.middleware import SlowAPIMiddleware
9
+ from typing import Dict, List
10
+ from prometheus_client import Counter, Histogram, start_http_server
11
+ from pydantic import BaseModel, ValidationError
12
+ from app.services.message import generate_reply, send_reply
13
+ import logging
14
+ from datetime import datetime
15
+ from sentence_transformers import SentenceTransformer
16
+
17
+ from contextlib import asynccontextmanager
18
+ # from app.db.database import create_indexes, init_db
19
+ from app.services.webhook_handler import verify_webhook
20
+ from app.handlers.message_handler import MessageHandler
21
+ from app.handlers.webhook_handler import WebhookHandler
22
+ from app.handlers.media_handler import WhatsAppMediaHandler
23
+ from app.services.cache import MessageCache
24
+ from app.services.chat_manager import ChatManager
25
+ from app.api.api_file import file_router
26
+ from app.utils.load_env import ACCESS_TOKEN
27
+ from app.search.rag_pipeline import RAGSystem
28
+
29
+ # Configure logging
30
+ logging.basicConfig(
31
+ level=logging.INFO,
32
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
33
+ )
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+ # Initialize handlers at startup
38
+ message_handler = None
39
+ webhook_handler = None
40
+
41
+
42
+ async def setup_message_handler():
43
+ logger = logging.getLogger(__name__)
44
+ message_cache = MessageCache()
45
+ chat_manager = ChatManager()
46
+ media_handler = WhatsAppMediaHandler()
47
+
48
+ return MessageHandler(
49
+ message_cache=message_cache,
50
+ chat_manager=chat_manager,
51
+ media_handler=media_handler,
52
+ logger=logger
53
+ )
54
+
55
+ async def setup_rag_system():
56
+ embedding_model = SentenceTransformer('all-MiniLM-L6-v2') # Replace with your model if different
57
+ rag_system = RAGSystem(embedding_model)
58
+
59
+ return rag_system
60
+
61
+ # Initialize FastAPI app
62
+ @asynccontextmanager
63
+ async def lifespan(app: FastAPI):
64
+
65
+
66
+ try:
67
+ # await init_db()
68
+
69
+ logger.info("Connected to the MongoDB database!")
70
+
71
+ rag_system = await setup_rag_system()
72
+
73
+ app.state.rag_system = rag_system
74
+
75
+ global message_handler, webhook_handler
76
+ message_handler = await setup_message_handler()
77
+ webhook_handler = WebhookHandler(message_handler)
78
+ # collections = app.database.list_collection_names()
79
+ # print(f"Collections in {db_name}: {collections}")
80
+ yield
81
+ except Exception as e:
82
+ logger.error(e)
83
+
84
+ # Initialize Limiter and Prometheus Metrics
85
+ limiter = Limiter(key_func=get_remote_address)
86
+ app = FastAPI(lifespan=lifespan)
87
+ app.state.limiter = limiter
88
+ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
89
+
90
+ # Add SlowAPI Middleware
91
+ app.add_middleware(SlowAPIMiddleware)
92
+
93
+ # app.include_router(users.router, prefix="/users", tags=["Users"])
94
+ app.include_router(file_router, prefix="/file_load", tags=["File Load"])
95
+
96
+ # Prometheus metrics
97
+ webhook_requests = Counter('webhook_requests_total', 'Total webhook requests')
98
+ webhook_processing_time = Histogram('webhook_processing_seconds', 'Time spent processing webhook')
99
+
100
+ # Pseudocode: You may have a database or environment variable to validate API keys
101
+ VALID_API_KEYS = {"!@#$%^"}
102
+ class WebhookPayload(BaseModel):
103
+ entry: List[Dict]
104
+
105
+ @app.post("/api/v1/messages")
106
+ @limiter.limit("20/minute")
107
+ async def process_message(request: Request):
108
+ try:
109
+ # Validate developer’s API key
110
+ api_key = request.headers.get("Authorization")
111
+ if not api_key or not api_key.startswith("Bearer "):
112
+ raise HTTPException(status_code=401, detail="Missing or invalid API key")
113
+
114
+ api_key_value = api_key.replace("Bearer ", "")
115
+ if api_key_value not in VALID_API_KEYS:
116
+ raise HTTPException(status_code=403, detail="Forbidden")
117
+
118
+ payload = await request.json()
119
+
120
+ # Extract needed credentials from query params or request body
121
+ # e.g., whatsapp_token, verify_token, llm_api_key, llm_model
122
+ whatsapp_token = request.query_params.get("whatsapp_token")
123
+ whatsapp_url = request.query_params.get("whatsapp_url")
124
+ gemini_api = request.query_params.get("gemini_api")
125
+ llm_model = request.query_params.get("cx_code")
126
+
127
+ print(f"payload: {payload}")
128
+ response = await webhook_handler.process_webhook(
129
+ payload=payload,
130
+ whatsapp_token=whatsapp_token,
131
+ whatsapp_url=whatsapp_url,
132
+ gemini_api=gemini_api,
133
+ )
134
+
135
+ return JSONResponse(
136
+ content=response.__dict__,
137
+ status_code=status.HTTP_200_OK
138
+ )
139
+
140
+ except ValidationError as ve:
141
+ logger.error(f"Validation error: {ve}")
142
+ return JSONResponse(
143
+ content={"status": "error", "detail": ve.errors()},
144
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
145
+ )
146
+ except Exception as e:
147
+ logger.error(f"Unexpected error: {str(e)}")
148
+ return JSONResponse(
149
+ content={"status": "error", "detail": str(e)},
150
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
151
+ )
152
+
153
+ app.get("/webhook")(verify_webhook)
app/crud/process_file.py ADDED
@@ -0,0 +1,152 @@
1
+ # app/crud.py
2
+ from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, CSVLoader, UnstructuredExcelLoader
3
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
4
+ from sqlalchemy.future import select
5
+ from sqlalchemy.ext.asyncio import AsyncSession
6
+ from markitdown import MarkItDown
7
+ import os
8
+ import logging
9
+
10
+
11
+
12
+ from typing import List, Optional
13
+ # from app.db.models.docs import *
14
+ # from app.schemas.schemas import DocumentCreate, DocumentUpdate
15
+
16
+ # Configure logging
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+
20
+ async def load_file_with_markitdown(file_path:str, llm_client:str=None, model:str=None):
21
+
22
+ if llm_client and model:
23
+ markitdown = MarkItDown(llm_client=llm_client, llm_model=model)
24
+ documents = markitdown.convert(file_path)
25
+ else:
26
+ markitdown = MarkItDown()
27
+ documents = markitdown.convert(file_path)
28
+
29
+ return documents
30
+
31
+
32
+
33
+ async def load_pdf_with_langchain(file_path):
34
+ """
35
+ Loads and extracts text from a PDF file using LangChain's PyPDFLoader.
36
+
37
+ Parameters:
38
+ file_path (str): Path to the PDF file.
39
+
40
+ Returns:
41
+ List[Document]: A list of LangChain Document objects with metadata.
42
+ """
43
+
44
+ loader = PyPDFLoader(file_path, extract_images=True)
45
+
46
+ documents = loader.load()
47
+
48
+ return documents # Returns a list of Document objects
49
+
50
+ async def load_file_with_langchain(file_path: str):
51
+ """
52
+ Loads and extracts text from a PDF, DOCX, CSV, or XLSX file using LangChain's appropriate loader.
53
+
54
+ Parameters:
55
+ file_path (str): Path to the file (PDF, DOCX, CSV, or XLSX).
56
+
57
+ Returns:
58
+ List[Document]: A list of LangChain Document objects with metadata.
59
+ """
60
+ # Determine the file extension
61
+ _, file_extension = os.path.splitext(file_path)
62
+
63
+ # Choose the loader based on file extension
64
+ if file_extension.lower() == '.pdf':
65
+ loader = PyPDFLoader(file_path)
66
+ elif file_extension.lower() == '.docx':
67
+ loader = Docx2txtLoader(file_path)
68
+ elif file_extension.lower() == '.csv':
69
+ loader = CSVLoader(file_path)
70
+ elif file_extension.lower() == '.xlsx':
71
+ loader = UnstructuredExcelLoader(file_path)
72
+ else:
73
+ raise ValueError("Unsupported file format. Please provide a PDF or DOCX file.")
74
+
75
+ # Load the documents
76
+ documents = loader.load()
77
+
78
+ return documents
79
+
80
+ async def split_documents(documents, chunk_size=10000, chunk_overlap=1000):
81
+ """
82
+ Splits documents into smaller chunks with overlap.
83
+
84
+ Parameters:
85
+ documents (List[Document]): List of LangChain Document objects.
86
+ chunk_size (int): The maximum size of each chunk.
87
+ chunk_overlap (int): The number of characters to overlap between chunks.
88
+
89
+ Returns:
90
+ List[Document]: List of chunked Document objects.
91
+ """
92
+ text_splitter = RecursiveCharacterTextSplitter(
93
+ chunk_size=chunk_size,
94
+ chunk_overlap=chunk_overlap,
95
+ )
96
+ split_docs = text_splitter.split_documents(documents)
97
+ return split_docs
98
+
99
+ async def process_uploaded_file(
100
+ id, file_path,
101
+ rag_system=None,
102
+ llm_client=None,
103
+ llm_model=None
104
+ ):
105
+
106
+ try:
107
+ # Load the document using LangChain
108
+ documents = await load_file_with_markitdown(file_path, llm_client=llm_client, model=llm_model)
109
+ logger.info(f"Loaded document: {file_path}")
110
+
111
+ # Concatenate all pages to get the full document text for context generation
112
+ # whole_document_content = "\n".join([doc.page_content for doc in documents])
113
+
114
+ except Exception as e:
115
+ logger.error(f"Failed to load document {file_path}: {e}")
116
+ raise RuntimeError(f"Error loading document: {file_path}") from e
117
+
118
+ # # Generate context for each chunk if llm is provided
119
+ # if llm:
120
+ # for doc in split_docs:
121
+ # try:
122
+ # context = await llm.generate_context(doc, whole_document_content=whole_document_content)
123
+ # # Add context to the beginning of the page content
124
+ # doc.page_content = f"{context.replace('<|eot_id|>', '')}\n\n{doc.page_content}"
125
+ # logger.info(f"Context generated and added for chunk {split_docs.index(doc)}")
126
+ # except Exception as e:
127
+ # logger.error(f"Failed to generate context for chunk {split_docs.index(doc)}: {e}")
128
+ # raise RuntimeError(f"Error generating context for chunk {split_docs.index(doc)}") from e
129
+
130
+ # Add to RAG system if rag_system is provided and load_only is False
131
+ if rag_system:
132
+ try:
133
+ rag_system.add_document(doc_id = f"{id}_{documents.title}", text = documents.text_content)
134
+
135
+ print(f"doc_id: {id}_{documents.title}")
136
+ print(f"content: {documents.text_content}")
137
+
138
+ # print(f"New Page Content: {doc.page_content}")
139
+ logger.info(f"Document chunks successfully added to RAG system for file {file_path}")
140
+
141
+ except Exception as e:
142
+ logger.error(f"Failed to add document chunks to RAG system for {file_path}: {e}")
143
+ raise RuntimeError(f"Error adding document to RAG system: {file_path}") from e
144
+ else:
145
+ logger.info(f"Loaded document {file_path}, but not added to RAG system")
146
+
147
+ return documents
148
+
149
+
150
+
151
+
152
+
app/handlers/media_handler.py CHANGED
@@ -7,9 +7,9 @@ logger = logging.getLogger(__name__)
7
 
8
  class MediaHandler(ABC):
9
  @abstractmethod
10
- async def download(self, media_id: str, access_token: str, file_path: str) -> str:
11
  pass
12
 
13
  class WhatsAppMediaHandler(MediaHandler):
14
- async def download(self, media_id: str, access_token: str, file_path: str) -> str:
15
- return await download_whatsapp_media(media_id, access_token, file_path)
 
7
 
8
  class MediaHandler(ABC):
9
  @abstractmethod
10
+ async def download(self, media_id: str, whatsapp_token: str, file_path: str) -> str:
11
  pass
12
 
13
  class WhatsAppMediaHandler(MediaHandler):
14
+ async def download(self, media_id: str, whatsapp_token: str, file_path: str) -> str:
15
+ return await download_whatsapp_media(media_id, whatsapp_token, file_path)
app/handlers/message_handler.py CHANGED
@@ -26,7 +26,7 @@ class MessageHandler:
26
  self.media_handler = media_handler
27
  self.logger = logger
28
 
29
- async def handle(self, raw_message: dict, access_token: str) -> dict:
30
  try:
31
  # Parse message
32
  message = MessageParser.parse(raw_message)
@@ -36,7 +36,7 @@ class MessageHandler:
36
  return {"status": "duplicate", "message_id": message.id}
37
 
38
  # Download media
39
- media_paths = await self._process_media(message, access_token)
40
 
41
  self.chat_manager.initialize_chat(message.sender_id)
42
 
@@ -45,7 +45,9 @@ class MessageHandler:
45
  result = await process_message_with_llm(
46
  message.sender_id,
47
  message.content,
48
- self.chat_manager.get_chat_history(message.sender_id),
49
  **media_paths
50
  )
51
 
@@ -60,7 +62,7 @@ class MessageHandler:
60
  except Exception as e:
61
  return {"status": "error", "message_id": raw_message.get("id"), "error": str(e)}
62
 
63
- async def _process_media(self, message: Message, access_token: str) -> Dict[str, Optional[str]]:
64
  media_paths = {
65
  "image_file_path": None,
66
  "doc_path": None,
@@ -74,7 +76,7 @@ class MessageHandler:
74
  self.logger.info(f"Processing {media_type.value}: {content.file_path}")
75
  file_path = await self.media_handler.download(
76
  content.id,
77
- access_token,
78
  content.file_path
79
  )
80
  self.logger.info(f"{media_type.value} file_path: {file_path}")
 
26
  self.media_handler = media_handler
27
  self.logger = logger
28
 
29
+ async def handle(self, raw_message: dict, whatsapp_token: str, whatsapp_url:str,gemini_api:str) -> dict:
30
  try:
31
  # Parse message
32
  message = MessageParser.parse(raw_message)
 
36
  return {"status": "duplicate", "message_id": message.id}
37
 
38
  # Download media
39
+ media_paths = await self._process_media(message, whatsapp_token)
40
 
41
  self.chat_manager.initialize_chat(message.sender_id)
42
 
 
45
  result = await process_message_with_llm(
46
  message.sender_id,
47
  message.content,
48
+ self.chat_manager.get_chat_history(message.sender_id),
49
+ whatsapp_token=whatsapp_token,
50
+ whatsapp_url=whatsapp_url,
51
  **media_paths
52
  )
53
 
 
62
  except Exception as e:
63
  return {"status": "error", "message_id": raw_message.get("id"), "error": str(e)}
64
 
65
+ async def _process_media(self, message: Message, whatsapp_token: str) -> Dict[str, Optional[str]]:
66
  media_paths = {
67
  "image_file_path": None,
68
  "doc_path": None,
 
76
  self.logger.info(f"Processing {media_type.value}: {content.file_path}")
77
  file_path = await self.media_handler.download(
78
  content.id,
79
+ whatsapp_token,
80
  content.file_path
81
  )
82
  self.logger.info(f"{media_type.value} file_path: {file_path}")
app/handlers/webhook_handler.py CHANGED
@@ -18,7 +18,7 @@ class WebhookHandler:
18
  self.message_handler = message_handler
19
  self.logger = logging.getLogger(__name__)
20
 
21
- async def process_webhook(self, payload: dict, access_token: str) -> WebhookResponse:
22
  request_id = f"req_{int(time.time()*1000)}"
23
  results = []
24
 
@@ -37,7 +37,9 @@ class WebhookHandler:
37
  self.logger.info(f"Processing message: {message}")
38
  response = await self.message_handler.handle(
39
  raw_message=message,
40
- access_token=access_token
41
  )
42
  results.append(response)
43
 
 
18
  self.message_handler = message_handler
19
  self.logger = logging.getLogger(__name__)
20
 
21
+ async def process_webhook(self, payload: dict, whatsapp_token: str, whatsapp_url:str,gemini_api:str) -> WebhookResponse:
22
  request_id = f"req_{int(time.time()*1000)}"
23
  results = []
24
 
 
37
  self.logger.info(f"Processing message: {message}")
38
  response = await self.message_handler.handle(
39
  raw_message=message,
40
+ whatsapp_token=whatsapp_token,
41
+ whatsapp_url=whatsapp_url,
42
+ gemini_api=gemini_api,
43
  )
44
  results.append(response)
45
 
app/main.py CHANGED
@@ -21,7 +21,7 @@ from app.handlers.webhook_handler import WebhookHandler
21
  from app.handlers.media_handler import WhatsAppMediaHandler
22
  from app.services.cache import MessageCache
23
  from app.services.chat_manager import ChatManager
24
-
25
  from app.utils.load_env import ACCESS_TOKEN
26
 
27
  # Configure logging
@@ -79,7 +79,7 @@ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
79
  app.add_middleware(SlowAPIMiddleware)
80
 
81
  # app.include_router(users.router, prefix="/users", tags=["Users"])
82
-
83
 
84
  # Prometheus metrics
85
  webhook_requests = Counter('webhook_requests_total', 'Total webhook requests')
@@ -105,11 +105,19 @@ async def webhook(request: Request):
105
  # Process the webhook payload here
106
  # For example:
107
  # results = process_webhook_entries(validated_payload.entry)
108
  response = await webhook_handler.process_webhook(
109
  payload=payload,
110
- access_token=ACCESS_TOKEN
111
  )
112
-
113
  return JSONResponse(
114
  content=response.__dict__,
115
  status_code=status.HTTP_200_OK
 
21
  from app.handlers.media_handler import WhatsAppMediaHandler
22
  from app.services.cache import MessageCache
23
  from app.services.chat_manager import ChatManager
24
+ from app.api.api_prompt import prompt_router
25
  from app.utils.load_env import ACCESS_TOKEN
26
 
27
  # Configure logging
 
79
  app.add_middleware(SlowAPIMiddleware)
80
 
81
  # app.include_router(users.router, prefix="/users", tags=["Users"])
82
+ app.include_router(prompt_router, prefix="/prompts", tags=["Prompts"])
83
 
84
  # Prometheus metrics
85
  webhook_requests = Counter('webhook_requests_total', 'Total webhook requests')
 
105
  # Process the webhook payload here
106
  # For example:
107
  # results = process_webhook_entries(validated_payload.entry)
108
+ # e.g., whatsapp_token, verify_token, llm_api_key, llm_model
109
+ whatsapp_token = request.query_params.get("whatsapp_token")
110
+ whatsapp_url = request.query_params.get("whatsapp_url")
111
+ gemini_api = request.query_params.get("gemini_api")
112
+ llm_model = request.query_params.get("cx_code")
113
+
114
+ print(f"payload: {payload}")
115
  response = await webhook_handler.process_webhook(
116
  payload=payload,
117
+ whatsapp_token=whatsapp_token,
118
+ whatsapp_url=whatsapp_url,
119
+ gemini_api=gemini_api,
120
  )
 
121
  return JSONResponse(
122
  content=response.__dict__,
123
  status_code=status.HTTP_200_OK
app/search/bm25_search.py ADDED
@@ -0,0 +1,159 @@
1
+ # bm25_search.py
2
+ import asyncio
3
+ from rank_bm25 import BM25Okapi
4
+ import nltk
5
+ import string
6
+ from typing import List, Set, Optional
7
+ from nltk.corpus import stopwords
8
+ from nltk.stem import WordNetLemmatizer
9
+
10
+
11
+ def download_nltk_resources():
12
+ """
13
+ Downloads required NLTK resources synchronously.
14
+ """
15
+ resources = ['punkt', 'stopwords', 'wordnet', 'omw-1.4']
16
+ for resource in resources:
17
+ try:
18
+ nltk.download(resource, quiet=True)
19
+ except Exception as e:
20
+ print(f"Error downloading {resource}: {str(e)}")
21
+
22
+ class BM25_search:
23
+ # Class variable to track if resources have been downloaded
24
+ nltk_resources_downloaded = False
25
+
26
+ def __init__(self, remove_stopwords: bool = True, perform_lemmatization: bool = False):
27
+ """
28
+ Initializes the BM25_search.
29
+
30
+ Parameters:
31
+ - remove_stopwords (bool): Whether to remove stopwords during preprocessing.
32
+ - perform_lemmatization (bool): Whether to perform lemmatization on tokens.
33
+ """
34
+ # Ensure NLTK resources are downloaded only once
35
+ if not BM25_search.nltk_resources_downloaded:
36
+ download_nltk_resources()
37
+ BM25_search.nltk_resources_downloaded = True # Mark as downloaded
38
+
39
+ self.documents: List[str] = []
40
+ self.doc_ids: List[str] = []
41
+ self.tokenized_docs: List[List[str]] = []
42
+ self.bm25: Optional[BM25Okapi] = None
43
+ self.remove_stopwords = remove_stopwords
44
+ self.perform_lemmatization = perform_lemmatization
45
+ self.stop_words: Set[str] = set(stopwords.words('english')) if remove_stopwords else set()
46
+ self.lemmatizer = WordNetLemmatizer() if perform_lemmatization else None
47
+
48
+ def preprocess(self, text: str) -> List[str]:
49
+ """
50
+ Preprocesses the input text by lowercasing, removing punctuation,
51
+ tokenizing, removing stopwords, and optionally lemmatizing.
52
+ """
53
+ text = text.lower().translate(str.maketrans('', '', string.punctuation))
54
+ tokens = nltk.word_tokenize(text)
55
+ if self.remove_stopwords:
56
+ tokens = [token for token in tokens if token not in self.stop_words]
57
+ if self.perform_lemmatization and self.lemmatizer:
58
+ tokens = [self.lemmatizer.lemmatize(token) for token in tokens]
59
+ return tokens
60
+
61
+ def add_document(self, doc_id: str, new_doc: str) -> None:
62
+ """
63
+ Adds a new document to the corpus and updates the BM25 index.
64
+ """
65
+ processed_tokens = self.preprocess(new_doc)
66
+
67
+ self.documents.append(new_doc)
68
+ self.doc_ids.append(doc_id)
69
+ self.tokenized_docs.append(processed_tokens)
70
+ # Ensure update_bm25 is awaited if required in async context
71
+ self.update_bm25()
72
+ print(f"Added document ID: {doc_id}")
73
+
74
+ async def remove_document(self, index: int) -> None:
75
+ """
76
+ Removes a document from the corpus based on its index and updates the BM25 index.
77
+ """
78
+ if 0 <= index < len(self.documents):
79
+ removed_doc_id = self.doc_ids[index]
80
+ del self.documents[index]
81
+ del self.doc_ids[index]
82
+ del self.tokenized_docs[index]
83
+ self.update_bm25()
84
+ print(f"Removed document ID: {removed_doc_id}")
85
+ else:
86
+ print(f"Index {index} is out of bounds.")
87
+
88
+ def update_bm25(self) -> None:
89
+ """
90
+ Updates the BM25 index based on the current tokenized documents.
91
+ """
92
+ if self.tokenized_docs:
93
+ self.bm25 = BM25Okapi(self.tokenized_docs)
94
+ print("BM25 index has been initialized.")
95
+ else:
96
+ print("No documents to initialize BM25.")
97
+
98
+
99
+ def get_scores(self, query: str) -> List[float]:
100
+ """
101
+ Computes BM25 scores for all documents based on the given query.
102
+ """
103
+ processed_query = self.preprocess(query)
104
+ print(f"Tokenized Query: {processed_query}")
105
+
106
+ if self.bm25:
107
+ return self.bm25.get_scores(processed_query)
108
+ else:
109
+ print("BM25 is not initialized.")
110
+ return []
111
+
112
+ def get_top_n_docs(self, query: str, n: int = 5) -> List[str]:
113
+ """
114
+ Returns the top N documents for a given query.
115
+ """
116
+ processed_query = self.preprocess(query)
117
+ if self.bm25:
118
+ return self.bm25.get_top_n(processed_query, self.documents, n)
119
+ else:
120
+ print("initialized.")
121
+ return []
122
+
123
+ def clear_documents(self) -> None:
124
+ """
125
+ Clears all documents from the BM25 index.
126
+ """
127
+ self.documents = []
128
+ self.doc_ids = []
129
+ self.tokenized_docs = []
130
+ self.bm25 = None # Reset BM25 index
131
+ print("BM25 documents cleared and index reset.")
132
+
133
+ def get_document(self, doc_id: str) -> str:
134
+ """
135
+ Retrieves a document by its document ID.
136
+
137
+ Parameters:
138
+ - doc_id (str): The ID of the document to retrieve.
139
+
140
+ Returns:
141
+ - str: The document text if found, otherwise an empty string.
142
+ """
143
+ try:
144
+ index = self.doc_ids.index(doc_id)
145
+ return self.documents[index]
146
+ except ValueError:
147
+ print(f"Document ID {doc_id} not found.")
148
+ return ""
149
+
150
+
151
+ async def initialize_bm25_search(remove_stopwords: bool = True, perform_lemmatization: bool = False) -> BM25_search:
152
+ """
153
+ Initializes the BM25_search with proper NLTK resource downloading.
154
+ """
155
+ loop = asyncio.get_running_loop()
156
+ await loop.run_in_executor(None, download_nltk_resources)
157
+ return BM25_search(remove_stopwords, perform_lemmatization)
158
+
159
+
app/search/faiss_search.py ADDED
@@ -0,0 +1,76 @@
1
+ # faiss_wrapper.py
2
+ import faiss
3
+ import numpy as np
4
+
5
+ class FAISS_search:
6
+ def __init__(self, embedding_model):
7
+ self.documents = []
8
+ self.doc_ids = []
9
+ self.embedding_model = embedding_model
10
+ self.dimension = len(embedding_model.encode("embedding"))
11
+ self.index = faiss.IndexIDMap(faiss.IndexFlatL2(self.dimension))
12
+
13
+ def add_document(self, doc_id, new_doc):
14
+ self.documents.append(new_doc)
15
+ self.doc_ids.append(doc_id)
16
+ # Encode and add document with its index as ID
17
+ embedding = self.embedding_model.encode([new_doc], convert_to_numpy=True).astype('float32')
18
+
19
+ if embedding.size == 0:
20
+ print("No documents to add to FAISS index.")
21
+ return
22
+
23
+ idx = len(self.documents) - 1
24
+ id_array = np.array([idx]).astype('int64')
25
+ self.index.add_with_ids(embedding, id_array)
26
+
27
+ def remove_document(self, index):
28
+ if 0 <= index < len(self.documents):
29
+ del self.documents[index]
30
+ del self.doc_ids[index]
31
+ # Rebuild the index
32
+ self.build_index()
33
+ else:
34
+ print(f"Index {index} is out of bounds.")
35
+
36
+ def build_index(self):
37
+ embeddings = self.embedding_model.encode(self.documents, convert_to_numpy=True).astype('float32')
38
+ idx_array = np.arange(len(self.documents)).astype('int64')
39
+ self.index = faiss.IndexIDMap(faiss.IndexFlatL2(self.dimension))
40
+ self.index.add_with_ids(embeddings, idx_array)
41
+
42
+ def search(self, query, k):
43
+ if self.index.ntotal == 0:
44
+ # No documents in the index
45
+ print("FAISS index is empty. No results can be returned.")
46
+ return np.array([]), np.array([]) # Return empty arrays for distances and indices
47
+ query_embedding = self.embedding_model.encode([query], convert_to_numpy=True).astype('float32')
48
+ distances, indices = self.index.search(query_embedding, k)
49
+ return distances[0], indices[0]
50
+
51
+ def clear_documents(self) -> None:
52
+ """
53
+ Clears all documents from the FAISS index.
54
+ """
55
+ self.documents = []
56
+ self.doc_ids = []
57
+ # Reset the FAISS index
58
+ self.index = faiss.IndexIDMap(faiss.IndexFlatL2(self.dimension))
59
+ print("FAISS documents cleared and index reset.")
60
+
61
+ def get_document(self, doc_id: str) -> str:
62
+ """
63
+ Retrieves a document by its document ID.
64
+
65
+ Parameters:
66
+ - doc_id (str): The ID of the document to retrieve.
67
+
68
+ Returns:
69
+ - str: The document text if found, otherwise an empty string.
70
+ """
71
+ try:
72
+ index = self.doc_ids.index(doc_id)
73
+ return self.documents[index]
74
+ except ValueError:
75
+ print(f"Document ID {doc_id} not found.")
76
+ return ""
app/search/hybrid_search.py ADDED
@@ -0,0 +1,247 @@
1
+ import numpy as np
2
+ import logging, torch
3
+ from sklearn.preprocessing import MinMaxScaler
4
+ from sentence_transformers import CrossEncoder
5
+ # from FlagEmbedding import FlagReranker
6
+
7
+
8
+ class Hybrid_search:
9
+ def __init__(self, bm25_search, faiss_search, reranker_model_name="BAAI/bge-reranker-v2-gemma", initial_bm25_weight=0.5):
10
+ self.bm25_search = bm25_search
11
+ self.faiss_search = faiss_search
12
+ self.bm25_weight = initial_bm25_weight
13
+ # self.reranker = FlagReranker(reranker_model_name, use_fp16=True)
14
+ self.logger = logging.getLogger(__name__)
15
+
16
+ async def advanced_search(self, query, keywords, top_n=5, threshold=0.53, prefixes=None):
17
+ # Dynamic BM25 weighting
18
+ self._dynamic_weighting(len(query.split()))
19
+ keywords = f"{' '.join(keywords)}"
20
+ self.logger.info(f"Query: {query}")
21
+ self.logger.info(f"Keywords: {keywords}")
22
+
23
+ # Get BM25 scores and doc_ids
24
+ bm25_scores, bm25_doc_ids = self._get_bm25_results(keywords, top_n = top_n)
25
+ # self.logger.info(f"BM25 Scores: {bm25_scores}, BM25 doc_ids: {bm25_doc_ids}")
26
+ # Get FAISS distances, indices, and doc_ids
27
+ faiss_distances, faiss_indices, faiss_doc_ids = self._get_faiss_results(query)
28
+ try:
29
+ faiss_distances, indices, faiss_doc_ids = self._get_faiss_results(query, top_n = top_n)
30
+ # for dist, idx, doc_id in zip(faiss_distances, indices, faiss_doc_ids):
31
+ # print(f"Distance: {dist:.4f}, Index: {idx}, Doc ID: {doc_id}")
32
+ except Exception as e:
33
+ self.logger.error(f"Search failed: {str(e)}")
34
+ # Map doc_ids to scores
35
+ bm25_scores_dict, faiss_scores_dict = self._map_scores_to_doc_ids(
36
+ bm25_doc_ids, bm25_scores, faiss_doc_ids, faiss_distances
37
+ )
38
+ # Create a unified set of doc IDs
39
+ all_doc_ids = sorted(set(bm25_doc_ids).union(faiss_doc_ids))
40
+ # print(f"All doc_ids: {all_doc_ids}, BM25 doc_ids: {bm25_doc_ids}, FAISS doc_ids: {faiss_doc_ids}")
41
+
42
+ # Filter doc_ids based on prefixes
43
+ filtered_doc_ids = self._filter_doc_ids_by_prefixes(all_doc_ids, prefixes)
44
+ # self.logger.info(f"Filtered doc_ids: {filtered_doc_ids}")
45
+
46
+ if not filtered_doc_ids:
47
+ self.logger.info("No documents match the prefixes.")
48
+ return []
49
+
50
+ # Prepare score lists
51
+ filtered_bm25_scores, filtered_faiss_scores = self._get_filtered_scores(
52
+ filtered_doc_ids, bm25_scores_dict, faiss_scores_dict
53
+ )
54
+ # self.logger.info(f"Filtered BM25 scores: {filtered_bm25_scores}")
55
+ # self.logger.info(f"Filtered FAISS scores: {filtered_faiss_scores}")
56
+
57
+
58
+ # Normalize scores
59
+ bm25_scores_normalized, faiss_scores_normalized = self._normalize_scores(
60
+ filtered_bm25_scores, filtered_faiss_scores
61
+ )
62
+
63
+ # Calculate hybrid scores
64
+ hybrid_scores = self._calculate_hybrid_scores(bm25_scores_normalized, faiss_scores_normalized)
65
+
66
+ # Display hybrid scores
67
+ for idx, doc_id in enumerate(filtered_doc_ids):
68
+ print(f"Hybrid Score: {hybrid_scores[idx]:.4f}, Doc ID: {doc_id}")
69
+
70
+ # Apply threshold and get top_n results
71
+ results = self._get_top_n_results(filtered_doc_ids, hybrid_scores, top_n, threshold)
72
+ # self.logger.info(f"Results before reranking: {results}")
73
+
74
+ # If results exist, apply re-ranking
75
+ # if results:
76
+ # re_ranked_results = self._rerank_results(query, results)
77
+ # self.logger.info(f"Results after reranking: {re_ranked_results}")
78
+ # return re_ranked_results
79
+
80
+ return results
81
+
82
+
83
+ def _dynamic_weighting(self, query_length):
84
+ if query_length <= 5:
85
+ self.bm25_weight = 0.7
86
+ else:
87
+ self.bm25_weight = 0.5
88
+ self.logger.info(f"Dynamic BM25 weight set to: {self.bm25_weight}")
89
+
90
+ def _get_bm25_results(self, keywords, top_n:int = None):
91
+ # Get BM25 scores
92
+ bm25_scores = np.array(self.bm25_search.get_scores(keywords))
93
+ bm25_doc_ids = np.array(self.bm25_search.doc_ids) # Assuming doc_ids is a list of document IDs
94
+
95
+ # Log the scores and IDs before filtering
96
+ # self.logger.info(f"BM25 scores: {bm25_scores}")
97
+ # self.logger.info(f"BM25 doc_ids: {bm25_doc_ids}")
98
+
99
+ # Get the top k indices based on BM25 scores
100
+ top_k_indices = np.argsort(bm25_scores)[-top_n:][::-1]
101
+
102
+ # Retrieve top k scores and corresponding document IDs
103
+ top_k_scores = bm25_scores[top_k_indices]
104
+ top_k_doc_ids = bm25_doc_ids[top_k_indices]
105
+
106
+ # Return top k scores and document IDs
107
+ return top_k_scores, top_k_doc_ids
108
+
109
+ def _get_faiss_results(self, query, top_n: int = None) -> tuple[np.ndarray, np.ndarray, list[str]]:
110
+
111
+ try:
112
+ # If top_k is not specified, use all documents
113
+ if top_n is None:
114
+ top_n = len(self.faiss_search.doc_ids)
115
+
116
+ # Use the search's search method which handles the embedding
117
+ distances, indices = self.faiss_search.search(query, k=top_n)
118
+
119
+ if len(distances) == 0 or len(indices) == 0:
120
+ # Handle case where FAISS returns empty results
121
+ self.logger.info("FAISS search returned no results.")
122
+ return np.array([]), np.array([]), []
123
+
124
+ # Filter out invalid indices (-1)
125
+ valid_mask = indices != -1
126
+ filtered_distances = distances[valid_mask]
127
+ filtered_indices = indices[valid_mask]
128
+
129
+ # Map indices to doc_ids
130
+ doc_ids = [self.faiss_search.doc_ids[idx] for idx in filtered_indices
131
+ if 0 <= idx < len(self.faiss_search.doc_ids)]
132
+
133
+ # self.logger.info(f"FAISS distances: {filtered_distances}")
134
+ # self.logger.info(f"FAISS indices: {filtered_indices}")
135
+ # self.logger.info(f"FAISS doc_ids: {doc_ids}")
136
+
137
+ return filtered_distances, filtered_indices, doc_ids
138
+
139
+ except Exception as e:
140
+ self.logger.error(f"Error in FAISS search: {str(e)}")
141
+ raise
142
+
143
+ def _map_scores_to_doc_ids(self, bm25_doc_ids, bm25_scores, faiss_doc_ids, faiss_scores):
144
+ bm25_scores_dict = dict(zip(bm25_doc_ids, bm25_scores))
145
+ faiss_scores_dict = dict(zip(faiss_doc_ids, faiss_scores))
146
+ # self.logger.info(f"BM25 scores dict: {bm25_scores_dict}")
147
+ # self.logger.info(f"FAISS scores dict: {faiss_scores_dict}")
148
+ return bm25_scores_dict, faiss_scores_dict
149
+
150
+ def _filter_doc_ids_by_prefixes(self, all_doc_ids, prefixes):
151
+ if prefixes:
152
+ filtered_doc_ids = [
153
+ doc_id
154
+ for doc_id in all_doc_ids
155
+ if any(doc_id.startswith(prefix) for prefix in prefixes)
156
+ ]
157
+ else:
158
+ filtered_doc_ids = list(all_doc_ids)
159
+ return filtered_doc_ids
160
+
161
+ def _get_filtered_scores(self, filtered_doc_ids, bm25_scores_dict, faiss_scores_dict):
162
+ # Initialize lists to hold scores in the unified doc ID order
163
+ bm25_aligned_scores = []
164
+ faiss_aligned_scores = []
165
+
166
+ # Populate aligned score lists, filling missing scores with neutral values
167
+ for doc_id in filtered_doc_ids:
168
+ bm25_aligned_scores.append(bm25_scores_dict.get(doc_id, 0)) # Use 0 if not found in BM25
169
+ faiss_aligned_scores.append(faiss_scores_dict.get(doc_id, max(faiss_scores_dict.values()) + 1)) # Use a high distance if not found in FAISS
170
+
171
+ # Invert the FAISS scores
172
+ faiss_aligned_scores = [1 / score if score != 0 else 0 for score in faiss_aligned_scores]
173
+
174
+ return bm25_aligned_scores, faiss_aligned_scores
175
+
176
+ def _normalize_scores(self, filtered_bm25_scores, filtered_faiss_scores):
177
+ scaler_bm25 = MinMaxScaler()
178
+ bm25_scores_normalized = self._normalize_array(filtered_bm25_scores, scaler_bm25)
179
+
180
+ scaler_faiss = MinMaxScaler()
181
+ faiss_scores_normalized = self._normalize_array(filtered_faiss_scores, scaler_faiss)
182
+
183
+ # self.logger.info(f"Normalized BM25 scores: {bm25_scores_normalized}")
184
+ # self.logger.info(f"Normalized FAISS scores: {faiss_scores_normalized}")
185
+ return bm25_scores_normalized, faiss_scores_normalized
186
+
187
+ def _normalize_array(self, scores, scaler):
188
+ scores_array = np.array(scores)
189
+ if np.ptp(scores_array) > 0:
190
+ normalized_scores = scaler.fit_transform(scores_array.reshape(-1, 1)).flatten()
191
+ else:
192
+ # Handle identical scores with a fallback to uniform 0.5
193
+ normalized_scores = np.full_like(scores_array, 0.5, dtype=float)
194
+ return normalized_scores
195
+
196
+ def _calculate_hybrid_scores(self, bm25_scores_normalized, faiss_scores_normalized):
197
+ hybrid_scores = self.bm25_weight * bm25_scores_normalized + (1 - self.bm25_weight) * faiss_scores_normalized
198
+ # self.logger.info(f"Hybrid scores: {hybrid_scores}")
199
+ return hybrid_scores
200
+
201
+ def _get_top_n_results(self, filtered_doc_ids, hybrid_scores, top_n, threshold):
202
+ hybrid_scores = np.array(hybrid_scores)
203
+ threshold_indices = np.where(hybrid_scores >= threshold)[0]
204
+ if len(threshold_indices) == 0:
205
+ self.logger.info("No documents meet the threshold.")
206
+ return []
207
+
208
+ sorted_indices = threshold_indices[np.argsort(hybrid_scores[threshold_indices])[::-1]]
209
+ top_indices = sorted_indices[:top_n]
210
+
211
+ results = [(filtered_doc_ids[idx], hybrid_scores[idx]) for idx in top_indices]
212
+ self.logger.info(f"Top {top_n} results: {results}")
213
+ return results
214
+
215
+ def _rerank_results(self, query, results):
216
+ """
217
+ Re-rank the retrieved documents using FlagReranker with normalized scores.
218
+
219
+ Parameters:
220
+ - query (str): The search query.
221
+ - results (List[Tuple[str, float]]): A list of (doc_id, score) tuples.
222
+
223
+ Returns:
224
+ - List[Tuple[str, float]]: Re-ranked list of (doc_id, score) tuples with normalized scores.
225
+ """
226
+ # Prepare input for the re-ranker
227
+ document_texts = [self.bm25_search.get_document(doc_id) for doc_id, _ in results]
228
+ doc_ids = [doc_id for doc_id, _ in results]
229
+
230
+ # Generate pairwise scores using the FlagReranker
231
+ rerank_inputs = [[query, doc] for doc in document_texts]
232
+ with torch.no_grad():
233
+ rerank_scores = self.reranker.compute_score(rerank_inputs, normalize=True)
234
+
235
+ # rerank_scores = self.reranker.compute_score(rerank_inputs, normalize=True)
236
+
237
+ # Combine doc_ids with normalized re-rank scores and sort by scores
238
+ reranked_results = sorted(
239
+ zip(doc_ids, rerank_scores),
240
+ key=lambda x: x[1],
241
+ reverse=True
242
+ )
243
+
244
+ # Log and return results
245
+ # self.logger.info(f"Re-ranked results with normalized scores: {reranked_results}")
246
+ return reranked_results
247
+
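
The final ranking is a weighted sum of the two normalized signals: hybrid = w * bm25_norm + (1 - w) * faiss_norm, where FAISS distances are first inverted and w is 0.7 for short queries and 0.5 otherwise. A toy numeric sketch of that arithmetic (all values invented):

import numpy as np
from sklearn.preprocessing import MinMaxScaler

bm25_raw = [2.1, 0.4, 1.3]    # lexical scores, higher is better
faiss_raw = [0.8, 0.2, 0.5]   # L2 distances, lower is better

faiss_sim = [1 / d if d != 0 else 0 for d in faiss_raw]  # invert, as in _get_filtered_scores

def norm(xs):
    xs = np.array(xs, dtype=float)
    if np.ptp(xs) > 0:
        return MinMaxScaler().fit_transform(xs.reshape(-1, 1)).flatten()
    return np.full_like(xs, 0.5)  # identical scores fall back to 0.5, as in _normalize_array

w = 0.7  # bm25_weight chosen by _dynamic_weighting for queries of <= 5 tokens
hybrid = w * norm(bm25_raw) + (1 - w) * norm(faiss_sim)
print(hybrid)  # -> approximately [0.70, 0.30, 0.43]
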
app/search/rag_pipeline.py ADDED
@@ -0,0 +1,147 @@
1
+ # rag_pipeline.py
2
+
3
+ import numpy as np
4
+
5
+ import pickle
6
+ import os
7
+ import logging
8
+ import asyncio
9
+
10
+ from app.search.bm25_search import BM25_search
11
+ from app.search.faiss_search import FAISS_search
12
+ from app.search.hybrid_search import Hybrid_search
13
+ from app.utils.token_counter import TokenCounter
14
+
15
+
16
+ # Configure logging
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # rag.py
21
+ class RAGSystem:
22
+ def __init__(self, embedding_model):
23
+ self.token_counter = TokenCounter()
24
+ self.documents = []
25
+ self.doc_ids = []
26
+ self.results = []
27
+ self.meta_data = []
28
+ self.embedding_model = embedding_model
29
+ self.bm25_wrapper = BM25_search()
30
+ self.faiss_wrapper = FAISS_search(embedding_model)
31
+ self.hybrid_search = Hybrid_search(self.bm25_wrapper, self.faiss_wrapper)
32
+
33
+ def add_document(self, doc_id, text, metadata=None):
34
+ self.token_counter.add_document(doc_id, text)
35
+ self.doc_ids.append(doc_id)
36
+ self.documents.append(text)
37
+ self.meta_data.append(metadata)
38
+ self.bm25_wrapper.add_document(doc_id, text)
39
+ self.faiss_wrapper.add_document(doc_id, text)
40
+
41
+ def delete_document(self, doc_id):
42
+ try:
43
+ index = self.doc_ids.index(doc_id)
44
+ del self.doc_ids[index]
45
+ del self.documents[index]
46
+ self.bm25_wrapper.remove_document(index)
47
+ self.faiss_wrapper.remove_document(index)
48
+ self.token_counter.remove_document(doc_id)
49
+ except ValueError:
50
+ logging.warning(f"Document ID {doc_id} not found.")
51
+
52
+ async def adv_query(self, query_text, keywords, top_k=15, prefixes=None):
53
+ results = await self.hybrid_search.advanced_search(
54
+ query_text,
55
+ keywords=keywords,
56
+ top_n=top_k,
57
+ threshold=0.43,
58
+ prefixes=prefixes
59
+ )
60
+ retrieved_docs = []
61
+ if results:
62
+ seen_docs = set()
63
+ for doc_id, score in results:
64
+ if doc_id not in seen_docs:
65
+ # Check if the doc_id exists in self.doc_ids
66
+ if doc_id not in self.doc_ids:
67
+ logger.error(f"doc_id {doc_id} not found in self.doc_ids")
68
+ seen_docs.add(doc_id)
69
+
70
+ # Fetch the index of the document
71
+ try:
72
+ index = self.doc_ids.index(doc_id)
73
+ except ValueError as e:
74
+ logger.error(f"Error finding index for doc_id {doc_id}: {e}")
75
+ continue
76
+
77
+ # Validate index range
78
+ if index >= len(self.documents) or index >= len(self.meta_data):
79
+ logger.error(f"Index {index} out of range for documents or metadata")
80
+ continue
81
+
82
+ doc = self.documents[index]
83
+
84
+ # meta_data = self.meta_data[index]
85
+ # Extract the file name and page number
86
+ # file_name = meta_data['source'].split('/')[-1] # Extracts 'POJK 31 - 2018.pdf'
87
+ # page_number = meta_data.get('page', 'unknown')
88
+ # url = meta_data['source']
89
+ # file_name = meta_data.get('source', 'unknown_source').split('/')[-1] # Safe extraction
90
+ # page_number = meta_data.get('page', 'unknown') # Default to 'unknown' if 'page' is missing
91
+ # url = meta_data.get('source', 'unknown_url') # Default URL fallback
92
+
93
+ # logger.info(f"file_name: {file_name}, page_number: {page_number}, url: {url}")
94
+
95
+ # Format as a single string
96
+ # content_string = f"'{file_name}', 'page': {page_number}"
97
+ # doc_name = f"{file_name}"
98
+
99
+ self.results.append(doc)
100
+ retrieved_docs.append({"text": doc})
101
+ return retrieved_docs
102
+ else:
103
+ return [{"name": "No relevant documents found.", "text": None}]
104
+
105
+ def get_total_tokens(self):
106
+ return self.token_counter.get_total_tokens()
107
+ def get_context(self):
108
+ context = "\n".join(self.results)
109
+ return context
110
+
111
+ def save_state(self, path):
112
+ # Save doc_ids, documents, and token counter state
113
+ with open(f"{path}_state.pkl", 'wb') as f:
114
+ pickle.dump({
115
+ "doc_ids": self.doc_ids,
116
+ "documents": self.documents,
117
+ "meta_data": self.meta_data,
118
+ "token_counts": self.token_counter.doc_tokens
119
+ }, f)
120
+
121
+ def load_state(self, path):
122
+ if os.path.exists(f"{path}_state.pkl"):
123
+ with open(f"{path}_state.pkl", 'rb') as f:
124
+ state_data = pickle.load(f)
125
+ self.doc_ids = state_data["doc_ids"]
126
+ self.documents = state_data["documents"]
127
+ self.meta_data = state_data["meta_data"]
128
+ self.token_counter.doc_tokens = state_data["token_counts"]
129
+
130
+ # Clear and rebuild BM25 and FAISS
131
+ self.bm25_wrapper.clear_documents()
132
+ self.faiss_wrapper.clear_documents()
133
+ for doc_id, document in zip(self.doc_ids, self.documents):
134
+ self.bm25_wrapper.add_document(doc_id, document)
135
+ self.faiss_wrapper.add_document(doc_id, document)
136
+
137
+ self.token_counter.total_tokens = sum(self.token_counter.doc_tokens.values())
138
+ logging.info("System state loaded successfully with documents and indices rebuilt.")
139
+ else:
140
+ logging.info("No previous state found, initializing fresh state.")
141
+ self.doc_ids = []
142
+ self.documents = []
143
+ self.meta_data = [] # Reset meta_data
144
+ self.token_counter = TokenCounter()
145
+ self.bm25_wrapper = BM25_search()
146
+ self.faiss_wrapper = FAISS_search(self.embedding_model)
147
+ self.hybrid_search = Hybrid_search(self.bm25_wrapper, self.faiss_wrapper)
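
An end-to-end sketch of the pipeline class above (document text and ID are invented; assumes the app package is importable and the model/NLTK downloads succeed). adv_query is async, so it is driven with asyncio here.

import asyncio
from sentence_transformers import SentenceTransformer
from app.search.rag_pipeline import RAGSystem

async def main():
    rag = RAGSystem(SentenceTransformer("all-MiniLM-L6-v2"))  # same model app.py wires in
    rag.add_document("faq_1_hours", "Our office hours are 9am to 5pm, Monday to Friday.")

    docs = await rag.adv_query("when are you open?", keywords=["office", "hours"], top_k=5)
    print(docs)                    # [{"text": ...}] or a "No relevant documents found." entry
    print(rag.get_total_tokens())  # running total maintained by TokenCounter

asyncio.run(main())
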
app/services/message.py CHANGED
@@ -6,7 +6,9 @@ from datetime import datetime
6
  import logging
7
  import asyncio
8
  from openai import AsyncOpenAI
 
9
  import google.generativeai as genai
 
10
  import PIL.Image
11
  from typing import List, Dict, Any, Optional
12
 
@@ -17,6 +19,29 @@ from app.services.search_engine import google_search
17
  # Load environment variables
18
  load_dotenv()
19
 
20
  genai.configure(api_key=GEMNI_API)
21
  client = AsyncOpenAI(api_key = OPENAI_API)
22
  # Configure logging
@@ -31,9 +56,9 @@ if not WHATSAPP_API_URL or not ACCESS_TOKEN:
31
  logger.warning("Environment variables for WHATSAPP_API_URL or ACCESS_TOKEN are not set!")
32
 
33
  # Helper function to send a reply
34
- async def send_reply(to: str, body: str) -> Dict[str, Any]:
35
  headers = {
36
- "Authorization": f"Bearer {ACCESS_TOKEN}",
37
  "Content-Type": "application/json"
38
  }
39
  data = {
@@ -46,7 +71,7 @@ async def send_reply(to: str, body: str) -> Dict[str, Any]:
46
  }
47
 
48
  async with httpx.AsyncClient() as client:
49
- response = await client.post(WHATSAPP_API_URL, json=data, headers=headers)
50
 
51
  if response.status_code != 200:
52
  error_detail = response.json()
@@ -74,7 +99,9 @@ async def generate_reply(sender: str, content: str, timestamp: int) -> str:
74
  async def process_message_with_llm(
75
  sender_id: str,
76
  content: str,
77
- history: List[Dict[str, str]],
 
 
78
  image_file_path: Optional[str] = None,
79
  doc_path: Optional[str] = None,
80
  video_file_path: Optional[str] = None,
@@ -92,7 +119,7 @@ async def process_message_with_llm(
92
  )
93
  logger.info(f"Generated reply: {generated_reply}")
94
 
95
- response = await send_reply(sender_id, generated_reply)
96
  # return generated_reply
97
  return generated_reply
98
  except Exception as e:
@@ -140,14 +167,55 @@ async def generate_response_from_gemini(
140
  pass # Placeholder for video processing logic
141
 
142
  # Send the user's message
143
- response = await chat.send_message_async(content, tools=[google_search])
 
144
  return response.text
145
 
146
  except Exception as e:
147
  logger.error("Error in generate_response_from_gemini:", exc_info=True)
148
  return "Sorry, I couldn't generate a response at this time."
149
 
150
-
151
 
152
  # Process message with retry logic
153
  # async def process_message_with_retry(
 
6
  import logging
7
  import asyncio
8
  from openai import AsyncOpenAI
9
+ import json
10
  import google.generativeai as genai
11
+
12
  import PIL.Image
13
  from typing import List, Dict, Any, Optional
14
 
 
19
  # Load environment variables
20
  load_dotenv()
21
 
22
+ # Define function specifications for Gemini
23
+ function_declarations = [
24
+ {
25
+ "name": "google_search",
26
+ "description": "Perform a Google search and retrieve search results",
27
+ "parameters": {
28
+ "type": "object",
29
+ "properties": {
30
+ "query": {
31
+ "type": "string",
32
+ "description": "The search query to perform"
33
+ },
34
+ "num_results": {
35
+ "type": "string",
36
+ "description": "Number of search results to retrieve (1-10)",
37
+ "default": "3"
38
+ }
39
+ },
40
+ "required": ["query"]
41
+ }
42
+ }
43
+ ]
44
+
45
  genai.configure(api_key=GEMNI_API)
46
  client = AsyncOpenAI(api_key = OPENAI_API)
47
  # Configure logging
 
56
  logger.warning("Environment variables for WHATSAPP_API_URL or ACCESS_TOKEN are not set!")
57
 
58
  # Helper function to send a reply
59
+ async def send_reply(to: str, body: str, whatsapp_token: str, whatsapp_url:str) -> Dict[str, Any]:
60
  headers = {
61
+ "Authorization": f"Bearer {whatsapp_token}",
62
  "Content-Type": "application/json"
63
  }
64
  data = {
 
71
  }
72
 
73
  async with httpx.AsyncClient() as client:
74
+ response = await client.post(whatsapp_url, json=data, headers=headers)
75
 
76
  if response.status_code != 200:
77
  error_detail = response.json()
 
99
  async def process_message_with_llm(
100
  sender_id: str,
101
  content: str,
102
+ history: List[Dict[str, str]],
103
+ whatsapp_token: str,
104
+ whatsapp_url:str,
105
  image_file_path: Optional[str] = None,
106
  doc_path: Optional[str] = None,
107
  video_file_path: Optional[str] = None,
 
119
  )
120
  logger.info(f"Generated reply: {generated_reply}")
121
 
122
+ response = await send_reply(sender_id, generated_reply, whatsapp_token, whatsapp_url)
123
  # return generated_reply
124
  return generated_reply
125
  except Exception as e:
 
167
  pass # Placeholder for video processing logic
168
 
169
  # Send the user's message
170
+ response = await chat.send_message_async(content)
171
+ # response = await handle_function_call(response)
172
  return response.text
173
 
174
  except Exception as e:
175
  logger.error("Error in generate_response_from_gemini:", exc_info=True)
176
  return "Sorry, I couldn't generate a response at this time."
177
 
178
+ async def handle_function_call(chat):
179
+ """
180
+ Handle function calls from the Gemini API.
181
+
182
+ Args:
183
+ chat (ChatSession): The current chat session.
184
+
185
+ Returns:
186
+ The response after resolving function calls.
187
+ """
188
+ # Continue the conversation and handle any function calls
189
+ while True:
190
+ response = await chat.send_message_async(chat.history[-1])
191
+
192
+ # Check if there are any function calls to handle
193
+ if response.candidates[0].content.parts[0].function_call:
194
+ function_call = response.candidates[0].content.parts[0].function_call
195
+ function_name = function_call.name
196
+ function_args = dict(function_call.args)  # args is a proto map, not a JSON string
197
+
198
+ # Dispatch to the appropriate function
199
+ if function_name == "google_search":
200
+ # Handle async function call
201
+ result = await google_search(
202
+ query=function_args['query'],
203
+ num_results=function_args.get('num_results', '3')
204
+ )
205
+
206
+
207
+ # Send the function result back to continue the conversation
208
+ response = await chat.send_message_async(
209
+ part={
210
+ "function_response": {
211
+ "name": function_name,
212
+ "response": result
213
+ }
214
+ }
215
+ )
216
+ else:
217
+ # No more function calls, return the final response
218
+ return response
219
 
220
  # Process message with retry logic
221
  # async def process_message_with_retry(
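Note that handle_function_call above is not awaited anywhere yet (the call inside generate_response_from_gemini is commented out). A minimal async round trip that resolves a single google_search call and returns the grounded answer could look like the sketch below; it reuses the genai.protos.FunctionResponse pattern from testcode.py and is an illustration, not part of this commit:

```python
# Sketch: assumes `chat` was created with model.start_chat(...) and google_search is importable.
import google.generativeai as genai
from app.services.search_engine import google_search

async def answer_with_search(chat, user_message: str) -> str:
    response = await chat.send_message_async(user_message)
    part = response.candidates[0].content.parts[0]
    if part.function_call:
        call = part.function_call
        if call.name == "google_search":
            result = await google_search(
                query=call.args["query"],
                num_results=call.args.get("num_results", "3"),
            )
            # Feed the tool output back so the model can compose the final reply.
            response = await chat.send_message_async([
                genai.protos.Part(
                    function_response=genai.protos.FunctionResponse(
                        name=call.name,
                        response={"result": result},
                    )
                )
            ])
    return response.text
```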
app/services/search_engine.py CHANGED
@@ -60,7 +60,8 @@ async def google_search(query: str, num_results: str = "3") -> Optional[List[Dic
60
  if response.status_code == 200:
61
  results = response.json()
62
  items = results.get("items", [])
63
- return [{"title": item["title"], "link": item["link"]} for item in items]
 
64
 
65
  else:
66
  logger.error(f"Google Search API error: {response.status_code} - {response.text}")
@@ -70,3 +71,18 @@ async def google_search(query: str, num_results: str = "3") -> Optional[List[Dic
70
  logger.error("A network error occurred while performing the Google search.")
71
  logger.error(f"Error details: {e}")
72
  return None
 
60
  if response.status_code == 200:
61
  results = response.json()
62
  items = results.get("items", [])
63
+
64
+ return [{"title": item["title"], "link": item["link"], "snippet": item.get("snippet", "")} for item in items]
65
 
66
  else:
67
  logger.error(f"Google Search API error: {response.status_code} - {response.text}")
 
71
  logger.error("A network error occurred while performing the Google search.")
72
  logger.error(f"Error details: {e}")
73
  return None
74
+
75
+ def set_light_values(brightness: str, color_temp: str) -> Dict[str, str]:
76
+ """Set the brightness and color temperature of a room light. (mock API).
77
+
78
+ Args:
79
+ brightness: Light level from 0 to 100. Zero is off and 100 is full brightness
80
+ color_temp: Color temperature of the light fixture, which can be `daylight`, `cool` or `warm`.
81
+
82
+ Returns:
83
+ A dictionary containing the set brightness and color temperature.
84
+ """
85
+ return {
86
+ "brightness": brightness,
87
+ "colorTemperature": color_temp
88
+ }
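Since google_search now returns a title, link and snippet per hit, a small helper can turn the results into a context block for the agentic prompt. This is an illustrative sketch, not code from the commit:

```python
from app.services.search_engine import google_search

async def build_search_context(query: str) -> str:
    # Returns a newline-separated summary of the top results, or a fallback string.
    results = await google_search(query, num_results="3")
    if not results:
        return "No search results available."
    return "\n".join(
        f"- {item['title']} ({item['link']}): {item['snippet']}" for item in results
    )
```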
app/utils/load_env.py CHANGED
@@ -20,8 +20,8 @@ CX_CODE = os.getenv("CX_CODE")
20
  CUSTOM_SEARCH_API_KEY = os.getenv("CUSTOM_SEARCH_API_KEY")
21
 
22
  # Debugging: Print the retrieved ACCESS_TOKEN (for development only)
23
- if ENV == "development":
24
- print(f"ACCESS_TOKEN loaded: {ACCESS_TOKEN}")
25
 
26
 
27
 
 
20
  CUSTOM_SEARCH_API_KEY = os.getenv("CUSTOM_SEARCH_API_KEY")
21
 
22
  # Debugging: Print the retrieved ACCESS_TOKEN (for development only)
23
+ # if ENV == "development":
24
+ # print(f"ACCESS_TOKEN loaded: {ACCESS_TOKEN}")
25
 
26
 
27
 
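If a development-time sanity check is still wanted without echoing the full secret, a masked variant could replace the commented-out print. This is a sketch, not part of this commit:

```python
# Assumes ENV and ACCESS_TOKEN are already loaded above.
if ENV == "development" and ACCESS_TOKEN:
    masked = ACCESS_TOKEN[:4] + "..." + ACCESS_TOKEN[-4:]
    print(f"ACCESS_TOKEN loaded (masked): {masked}")
```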
app/utils/system_prompt.py CHANGED
@@ -22,4 +22,10 @@ Example Interactions:
22
  If a user asks, “Are there any issues with the city government's policies?” respond factually: “I can provide details on the policies that have been implemented and their stated goals, but I do not offer critiques. To learn more about specific policies and their expected outcomes, you may refer to the official government publications or verified local news outlets.”
23
 
24
  By following these guidelines, you will serve as a reliable, respectful, and informative resource for users looking to understand the latest happenings in Surabaya without engaging in criticism of the government.
25
  """
 
22
  If a user asks, “Are there any issues with the city government's policies?” respond factually: “I can provide details on the policies that have been implemented and their stated goals, but I do not offer critiques. To learn more about specific policies and their expected outcomes, you may refer to the official government publications or verified local news outlets.”
23
 
24
  By following these guidelines, you will serve as a reliable, respectful, and informative resource for users looking to understand the latest happenings in Surabaya without engaging in criticism of the government.
25
+ """
26
+
27
+ agentic_prompt = """ You are a helpful assistant with the ability to search the web.
28
+ When links are given, summarize the content of each link and provide a short summary.
29
+ You should also include the source of each link in the summary.
30
+
31
  """
app/utils/token_counter.py ADDED
@@ -0,0 +1,29 @@
1
+ # token_counter.py
2
+
3
+ import os
4
+ import tiktoken
5
+
6
+ # Choose the encoding based on your model, e.g., 'cl100k_base' for OpenAI models
7
+ encoding = tiktoken.get_encoding("cl100k_base")
8
+
9
+ def count_tokens(text):
10
+ tokens = encoding.encode(text)
11
+ return len(tokens)
12
+
13
+ class TokenCounter:
14
+ def __init__(self):
15
+ self.total_tokens = 0
16
+ self.doc_tokens = {}
17
+
18
+ def add_document(self, doc_id, text):
19
+ num_tokens = count_tokens(text)
20
+ self.doc_tokens[doc_id] = num_tokens
21
+ self.total_tokens += num_tokens
22
+
23
+ def remove_document(self, doc_id):
24
+ if doc_id in self.doc_tokens:
25
+ self.total_tokens -= self.doc_tokens[doc_id]
26
+ del self.doc_tokens[doc_id]
27
+
28
+ def get_total_tokens(self):
29
+ return self.total_tokens
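A quick usage example of the TokenCounter above (the document ids and texts are illustrative):

```python
from app.utils.token_counter import TokenCounter

counter = TokenCounter()
counter.add_document("doc-1", "POJK 31 - 2018, first chapter ...")
counter.add_document("doc-2", "Surabaya city service announcement ...")
print(counter.get_total_tokens())   # total tokens across both documents

counter.remove_document("doc-1")
print(counter.get_total_tokens())   # only doc-2 remains counted
```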
app/utils/tool_call_extractor.py ADDED
@@ -0,0 +1,149 @@
1
+ import re
2
+ import json
3
+ from typing import List, Dict, Any, Optional
4
+
5
+ class ToolCallExtractor:
6
+ def __init__(self):
7
+ # Existing regex patterns (retain if needed for other formats)
8
+ self.complete_pattern = re.compile(r'<\|python_tag\|>(.*?)<\|eom_id\|>', re.DOTALL)
9
+ self.partial_pattern = re.compile(r'(.*?)<\|eom_id\|>', re.DOTALL)
10
+
11
+ def _extract_function_args(self, args) -> Dict[str, Any]:
12
+ """
13
+ Flatten the nested function args structure for Google AI protobuf types.
14
+ """
15
+ flattened_args = {}
16
+
17
+ try:
18
+ # Explicitly check for fields
19
+ if hasattr(args, 'fields'):
20
+ # Iterate through fields using to_dict() to convert protobuf to dict
21
+ for field in args.fields:
22
+ key = field.key
23
+ value = field.value
24
+
25
+ # Additional debugging
26
+ print(f"Field key: {key}")
27
+ print(f"Field value type: {type(value)}")
28
+ print(f"Field value: {value}")
29
+
30
+ # Extract string value
31
+ if hasattr(value, 'string_value'):
32
+ flattened_args[key] = value.string_value
33
+ print(f"Extracted string value: {value.string_value}")
34
+ elif hasattr(value, 'number_value'):
35
+ flattened_args[key] = value.number_value
36
+ elif hasattr(value, 'bool_value') and value.bool_value is not None:
37
+ flattened_args[key] = value.bool_value
38
+
39
+ # Added additional debug information
40
+ print(f"Final flattened args: {flattened_args}")
41
+
42
+ except Exception as e:
43
+ print(f"Error extracting function args: {e}")
44
+
45
+ return flattened_args
46
+
47
+ def extract_tool_calls(self, input_string: str) -> List[Dict[str, Any]]:
48
+ """
49
+ Extract tool calls from input string, handling various inconsistent formats.
50
+
51
+ Args:
52
+ input_string (str): The input string containing tool calls.
53
+
54
+ Returns:
55
+ list: A list of dictionaries representing the parsed tool calls.
56
+ """
57
+ tool_calls = []
58
+
59
+ # Existing tag-based extraction (retain if needed)
60
+ complete_matches = self.complete_pattern.findall(input_string)
61
+ if complete_matches:
62
+ for match in complete_matches:
63
+ tool_calls.extend(self._extract_json_objects(match))
64
+ return tool_calls
65
+
66
+ partial_matches = self.partial_pattern.findall(input_string)
67
+ if partial_matches:
68
+ for match in partial_matches:
69
+ tool_calls.extend(self._extract_json_objects(match))
70
+ return tool_calls
71
+
72
+ # Fallback: Attempt to parse the entire string
73
+ tool_calls.extend(self._extract_json_objects(input_string))
74
+
75
+ return tool_calls
76
+
77
+ def _extract_json_objects(self, text: str) -> List[Dict[str, Any]]:
78
+ """
79
+ Extract and parse multiple JSON objects from a string.
80
+ """
81
+ json_objects = []
82
+ potential_jsons = text.split(';')
83
+
84
+ for json_str in potential_jsons:
85
+ parsed_obj = self._clean_and_parse_json(json_str)
86
+ if parsed_obj:
87
+ json_objects.append(parsed_obj)
88
+
89
+ return json_objects
90
+
91
+ def _clean_and_parse_json(self, json_str: str) -> Optional[Dict[str, Any]]:
92
+ """
93
+ Clean and parse a JSON string, handling common formatting issues.
94
+ """
95
+ try:
96
+ json_str = json_str.strip()
97
+ if json_str.startswith('{') or json_str.startswith('['):
98
+ return json.loads(json_str)
99
+ return None
100
+ except json.JSONDecodeError:
101
+ return None
102
+
103
+ def validate_tool_call(self, tool_call: Dict[str, Any]) -> bool:
104
+ """
105
+ Validate if a tool call has the required fields.
106
+ """
107
+ return (
108
+ isinstance(tool_call, dict) and
109
+ 'name' in tool_call and
110
+ isinstance(tool_call['name'], str)
111
+ )
112
+
113
+ def extract_function_call(self, response_parts: List[Any]) -> Dict[str, Any]:
114
+ """
115
+ Extract function call details from the response parts.
116
+
117
+ Args:
118
+ response_parts (list): The list of response parts from the chat model.
119
+
120
+ Returns:
121
+ dict: A dictionary containing the function name and flattened arguments.
122
+ """
123
+ for part in response_parts:
124
+ # Debug print
125
+ print(f"Examining part: {part}")
126
+ print(f"Part type: {type(part)}")
127
+
128
+ # Check for function_call attribute
129
+ if hasattr(part, 'function_call') and part.function_call:
130
+ function_call = part.function_call
131
+
132
+ # Debug print
133
+ print(f"Function call: {function_call}")
134
+ print(f"Function call type: {type(function_call)}")
135
+ print(f"Function args: {function_call.args}")
136
+
137
+ # Extract function name
138
+ function_name = getattr(function_call, 'name', None)
139
+ if not function_name:
140
+ continue # Skip if function name is missing
141
+
142
+ # Extract function arguments
143
+ function_args = self._extract_function_args(function_call.args)
144
+
145
+ return {
146
+ "name": function_name,
147
+ "args": function_args
148
+ }
149
+ return {}
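For reference, extract_tool_calls accepts both the tag-delimited format matched by the regexes and bare JSON. A short illustrative example (the payload below is made up):

```python
from app.utils.tool_call_extractor import ToolCallExtractor

extractor = ToolCallExtractor()
raw = '<|python_tag|>{"name": "google_search", "parameters": {"query": "Surabaya news"}}<|eom_id|>'

for call in extractor.extract_tool_calls(raw):
    if extractor.validate_tool_call(call):
        print(call["name"], call.get("parameters", {}))
```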
document_logs_2024-12-20.txt ADDED
@@ -0,0 +1 @@
1
+ 2024-12-20 12:49:01,713 - INFO - ID: ea205193-4582-44bc-ab71-80176aac7aef, Snippet: [SSW](https://sswalfa.surabaya.go.id/home) [![](https://sswalfa.surabaya.go.id/assets/images/logo-
testcode.py CHANGED
@@ -1,4 +1,140 @@
1
  import PIL.Image
2
 
3
- organ = PIL.Image.open("organ.jpg")
4
- print(organ)
 
1
+ import os
2
+ import httpx
3
+ from dotenv import load_dotenv
4
+ from typing import Dict, Any, Optional, List, Iterable
5
+ from datetime import datetime
6
+ import logging
7
+ import asyncio
8
+ import json
9
+ import google.generativeai as genai
10
  import PIL.Image
11
 
12
+ # Import custom modules
13
+ from app.utils.load_env import ACCESS_TOKEN, WHATSAPP_API_URL, GEMNI_API, OPENAI_API
14
+ from app.utils.system_prompt import system_prompt, agentic_prompt
15
+ from google.generativeai.types import content_types
16
+ from testtool import ToolCallParser, FunctionExecutor
17
+ from app.services.search_engine import google_search, set_light_values
18
+
19
+ # Load environment variables
20
+ load_dotenv()
21
+
22
+ # Set up logging
23
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
24
+ logger = logging.getLogger(__name__)
25
+
26
+ def tool_config_from_mode(mode: str, fns: Iterable[str] = ()):
27
+ """
28
+ Create a tool config with the specified function calling mode.
29
+ """
30
+ return content_types.to_tool_config(
31
+ {"function_calling_config": {"mode": mode, "allowed_function_names": fns}}
32
+ )
33
+
34
+ def transform_result_to_response(results: List[Dict[str, Any]]) -> Dict[str, Any]:
35
+ """
36
+ Transform a list of result objects into a structured response dictionary.
37
+ """
38
+ response = {}
39
+ for res in results:
40
+ if res.get("status") == "success":
41
+ function_name = res.get("function")
42
+ function_result = res.get("result")
43
+ response[function_name] = function_result
44
+ else:
45
+ # Handle individual failures if necessary
46
+ response[res.get("function", "unknown_function")] = {
47
+ "error": "Function execution failed."
48
+ }
49
+ return response
50
+
51
+ async def process_tool_calls(input_string: str) -> List[Dict[str, Any]]:
52
+ """
53
+ Processes all tool calls extracted from the input string and executes them.
54
+ """
55
+ tool_calls = ToolCallParser.extract_tool_calls(input_string)
56
+ logger.info(f"Extracted tool_calls: {tool_calls}")
57
+ results = []
58
+ for tool_call in tool_calls:
59
+ result = await FunctionExecutor.call_function(tool_call)
60
+ results.append(result)
61
+ return results
62
+
63
+ async def main():
64
+ # Define available functions and tool configuration
65
+ available_functions = ["google_search", "set_light_values"]
66
+ config = tool_config_from_mode("any", fns=available_functions)
67
+
68
+ # Define chat history
69
+ history = [{"role": "user", "parts": "This is the chat history so far"}]
70
+
71
+ # Configure the Gemini API
72
+ genai.configure(api_key=GEMNI_API)
73
+ model = genai.GenerativeModel(
74
+ "gemini-1.5-pro-002",
75
+ system_instruction=agentic_prompt,
76
+ tools=[google_search, set_light_values]
77
+ )
78
+
79
+ # Start chat with history
80
+ chat = model.start_chat(history=history)
81
+
82
+ # Send the user's message and await the response
83
+ try:
84
+ response = chat.send_message(
85
+ "find the cheapest flight price from Medan to Jakarta on 1st January 2025",
86
+ tool_config=config
87
+ )
88
+ except Exception as e:
89
+ logger.error(f"Error sending message: {e}")
90
+ return
91
+
92
+ # Ensure that response.parts exists and is iterable
93
+ if not hasattr(response, 'parts') or not isinstance(response.parts, Iterable):
94
+ logger.error("Invalid response format: 'parts' attribute is missing or not iterable.")
95
+ return
96
+
97
+ # Convert response parts to a single input string
98
+ input_string = "\n".join(str(part) for part in response.parts)
99
+ logger.info(f"Input string for tool processing: {input_string}")
100
+
101
+ # Process tool calls
102
+ try:
103
+ results = await process_tool_calls(input_string)
104
+ except Exception as e:
105
+ logger.error(f"Error processing tool calls: {e}")
106
+ return
107
+
108
+ # Log and print the results
109
+ logger.info("Results from tool calls:")
110
+ for result in results:
111
+ logger.info(json.dumps(result, indent=4))
112
+ print(json.dumps(result, indent=4))
113
+
114
+ # Transform the results into the desired response format
115
+ responses = transform_result_to_response(results)
116
+
117
+ # Build the response parts for the chat
118
+ try:
119
+ response_parts = [
120
+ genai.protos.Part(
121
+ function_response=genai.protos.FunctionResponse(
122
+ name=fn,
123
+ response={"result": val}
124
+ )
125
+ )
126
+ for fn, val in responses.items()
127
+ ]
128
+ except Exception as e:
129
+ logger.error(f"Error building response parts: {e}")
130
+ return
131
+
132
+ # Send the function responses back to the chat
133
+ try:
134
+ final_response = chat.send_message(response_parts)
135
+ print(final_response.text)
136
+ except Exception as e:
137
+ logger.error(f"Error sending final response: {e}")
138
+
139
+ if __name__ == "__main__":
140
+ asyncio.run(main())