Spaces: AdrienB134 / matriv-rag-demo
AdrienB134 committed
Commit 7fdb8e9
Parent(s): 9ae6b81
Upload 54 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +1 -0
- Dockerfile +32 -0
- README.md +0 -8
- app.py +81 -0
- pyproject.toml +18 -0
- rag_demo/__init__.py +3 -0
- rag_demo/__pycache__/__init__.cpython-311.pyc +0 -0
- rag_demo/__pycache__/pipeline.cpython-311.pyc +0 -0
- rag_demo/__pycache__/settings.cpython-311.pyc +0 -0
- rag_demo/data/test.pdf +0 -0
- rag_demo/data/test2.pdf +3 -0
- rag_demo/infra/__pycache__/qdrant.cpython-311.pyc +0 -0
- rag_demo/infra/qdrant.py +25 -0
- rag_demo/pipeline.py +13 -0
- rag_demo/preprocessing/__init__.py +5 -0
- rag_demo/preprocessing/__pycache__/__init__.cpython-311.pyc +0 -0
- rag_demo/preprocessing/__pycache__/chunking.cpython-311.pyc +0 -0
- rag_demo/preprocessing/__pycache__/embed.cpython-311.pyc +0 -0
- rag_demo/preprocessing/__pycache__/load_to_vectordb.cpython-311.pyc +0 -0
- rag_demo/preprocessing/__pycache__/pdf_conversion.cpython-311.pyc +0 -0
- rag_demo/preprocessing/base/__init__.py +12 -0
- rag_demo/preprocessing/base/__pycache__/__init__.cpython-311.pyc +0 -0
- rag_demo/preprocessing/base/__pycache__/chunk.cpython-311.pyc +0 -0
- rag_demo/preprocessing/base/__pycache__/document.cpython-311.pyc +0 -0
- rag_demo/preprocessing/base/__pycache__/embedded_chunk.cpython-311.pyc +0 -0
- rag_demo/preprocessing/base/__pycache__/vectordb.cpython-311.pyc +0 -0
- rag_demo/preprocessing/base/chunk.py +13 -0
- rag_demo/preprocessing/base/document.py +19 -0
- rag_demo/preprocessing/base/embedded_chunk.py +34 -0
- rag_demo/preprocessing/base/embeddings.py +145 -0
- rag_demo/preprocessing/base/vectordb.py +289 -0
- rag_demo/preprocessing/chunking.py +26 -0
- rag_demo/preprocessing/embed.py +57 -0
- rag_demo/preprocessing/load_to_vectordb.py +30 -0
- rag_demo/preprocessing/pdf_conversion.py +33 -0
- rag_demo/rag/__pycache__/prompt_templates.cpython-311.pyc +0 -0
- rag_demo/rag/__pycache__/query_expansion.cpython-311.pyc +0 -0
- rag_demo/rag/__pycache__/reranker.cpython-311.pyc +0 -0
- rag_demo/rag/__pycache__/retriever.cpython-311.pyc +0 -0
- rag_demo/rag/base/__init__.py +3 -0
- rag_demo/rag/base/__pycache__/__init__.cpython-311.pyc +0 -0
- rag_demo/rag/base/__pycache__/query.cpython-311.pyc +0 -0
- rag_demo/rag/base/__pycache__/template_factory.cpython-311.pyc +0 -0
- rag_demo/rag/base/base.py +22 -0
- rag_demo/rag/base/query.py +29 -0
- rag_demo/rag/base/template_factory.py +22 -0
- rag_demo/rag/prompt_templates.py +38 -0
- rag_demo/rag/query_expansion.py +39 -0
- rag_demo/rag/reranker.py +24 -0
- rag_demo/rag/retriever.py +133 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
rag_demo/data/test2.pdf filter=lfs diff=lfs merge=lfs -text
Dockerfile
ADDED
@@ -0,0 +1,32 @@
# Use Python 3.11 as base image (pyproject.toml requires Python >= 3.11)
FROM python:3.11-slim

# Set working directory
WORKDIR /app

# Install system dependencies and uv
RUN apt-get update && apt-get install -y \
    poppler-utils \
    curl \
    && rm -rf /var/lib/apt/lists/* \
    && curl -LsSf https://astral.sh/uv/install.sh | sh \
    && echo 'export PATH="$HOME/.local/bin:$PATH"' >> ~/.bashrc

# Copy requirements first to leverage the Docker layer cache
COPY pyproject.toml .

# Install Python dependencies using uv (source ~/.bashrc so uv is on PATH)
RUN . ~/.bashrc && uv pip install -r pyproject.toml --system

# Copy the rest of the application
COPY . .

# Create directories for uploads and embeddings if they don't exist
RUN mkdir -p uploads embeddings

# Expose the port the app runs on
EXPOSE 7860

# Run the app
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md
CHANGED
@@ -1,8 +0,0 @@
---
title: matriv-rag-demo
colorFrom: blue
colorTo: red
sdk: docker
app_file: app.py
pinned: false
---
app.py
ADDED
@@ -0,0 +1,81 @@
from fastapi import FastAPI, File, UploadFile, Request
from fastapi.templating import Jinja2Templates
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import os
from rag_demo.pipeline import process_pdf
import nest_asyncio
from rag_demo.rag.retriever import RAGPipeline
from loguru import logger

app = FastAPI()

# Apply nest_asyncio at the start of the application
nest_asyncio.apply()

# Serve Jinja2 templates from the templates directory
templates = Jinja2Templates(directory="templates")

app.mount("/static", StaticFiles(directory="static"), name="static")


class ChatRequest(BaseModel):
    question: str


@app.get("/", response_class=HTMLResponse)
async def upload_page(request: Request):
    return templates.TemplateResponse("upload.html", {"request": request})


@app.get("/chat", response_class=HTMLResponse)
async def chat_page(request: Request):
    return templates.TemplateResponse("chat.html", {"request": request})


@app.post("/upload")
async def upload_pdf(request: Request, file: UploadFile = File(...)):
    try:
        # Create the data directory if it doesn't exist
        os.makedirs("data", exist_ok=True)

        file_path = f"data/{file.filename}"
        with open(file_path, "wb") as buffer:
            content = await file.read()
            buffer.write(content)

        # Process the PDF file (process_pdf is synchronous, so it is not awaited)
        process_pdf(file_path)

        # Return template response with success message
        return templates.TemplateResponse(
            "upload.html",
            {
                "request": request,
                "message": f"Successfully processed {file.filename}",
                "processing": False,
            },
        )
    except Exception as e:
        return templates.TemplateResponse(
            "upload.html", {"request": request, "error": str(e), "processing": False}
        )


@app.post("/chat")
async def chat(chat_request: ChatRequest):
    rag_pipeline = RAGPipeline()
    try:
        answer = rag_pipeline.rag(chat_request.question)
        logger.info(answer)
        return {"answer": answer}
    except Exception as e:
        return {"error": str(e)}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
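For reference, a minimal client sketch against the routes above (the base URL and file name are assumptions for a local run):

import requests

BASE = "http://localhost:7860"  # assumed local deployment

# Ingest a PDF through the /upload route
with open("my_doc.pdf", "rb") as f:  # hypothetical file
    requests.post(f"{BASE}/upload", files={"file": f})

# Ask a question through the /chat route
resp = requests.post(f"{BASE}/chat", json={"question": "What is this document about?"})
print(resp.json()["answer"])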
pyproject.toml
ADDED
@@ -0,0 +1,18 @@
[project]
name = "rag-base"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
    "loguru>=0.7.2",
    "langchain>=0.3.9",
    "marker-pdf>=1.0.2",
    "qdrant-client[fastembed]>=1.12.1",
    "fastapi>=0.115.6",
    "pydantic>=2.10.3",
    "python-multipart>=0.0.19",
    "uvicorn>=0.32.1",
    "huggingface-hub>=0.26.3",
    "llama-parse>=0.5.17",
]
rag_demo/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .infra.qdrant import connection

__all__ = ["connection"]
rag_demo/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (263 Bytes)
rag_demo/__pycache__/pipeline.cpython-311.pyc
ADDED
Binary file (738 Bytes)
rag_demo/__pycache__/settings.cpython-311.pyc
ADDED
Binary file (1.99 kB)
rag_demo/data/test.pdf
ADDED
Binary file (344 kB)
rag_demo/data/test2.pdf
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b3041eb7dd274b02a2f18049891dc3f184dff4151796f225b92cd34d676ba923
size 1962780
rag_demo/infra/__pycache__/qdrant.cpython-311.pyc
ADDED
Binary file (1.32 kB)
rag_demo/infra/qdrant.py
ADDED
@@ -0,0 +1,25 @@
from loguru import logger
from qdrant_client import QdrantClient


class QdrantDatabaseConnector:
    _instance: QdrantClient | None = None

    def __new__(cls, *args, **kwargs) -> QdrantClient:
        if cls._instance is None:
            try:
                # In-memory Qdrant instance; data is lost on restart
                cls._instance = QdrantClient(":memory:")

                logger.info("Connection to Qdrant DB successful")
            except Exception:
                logger.exception("Couldn't connect to Qdrant.")

                raise

        return cls._instance


connection = QdrantDatabaseConnector()
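Because __new__ caches and returns the QdrantClient itself, every instantiation yields the same in-memory client; a quick sketch of that behavior:

from rag_demo.infra.qdrant import QdrantDatabaseConnector, connection

client_a = QdrantDatabaseConnector()
client_b = QdrantDatabaseConnector()
assert client_a is client_b is connection  # one shared in-memory client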
rag_demo/pipeline.py
ADDED
@@ -0,0 +1,13 @@
from rag_demo.preprocessing import (
    convert_pdf_to_text,
    load_to_vector_db,
    chunk_and_embed,
)
from loguru import logger


def process_pdf(file_path: str):
    convert = convert_pdf_to_text([file_path])
    embedded_chunks = chunk_and_embed([convert])
    load_to_vector_db(embedded_chunks)
    return True
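The three preprocessing stages compose linearly, so the pipeline can also be driven directly; a minimal sketch (the PDF path is an assumption):

from rag_demo.pipeline import process_pdf

# convert -> chunk + embed -> load into Qdrant, end to end
process_pdf("rag_demo/data/test.pdf")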
rag_demo/preprocessing/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .pdf_conversion import convert_pdf_to_text
from .load_to_vectordb import load_to_vector_db
from .embed import chunk_and_embed

__all__ = ["convert_pdf_to_text", "load_to_vector_db", "chunk_and_embed"]
rag_demo/preprocessing/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (441 Bytes)
rag_demo/preprocessing/__pycache__/chunking.cpython-311.pyc
ADDED
Binary file (1.25 kB)
rag_demo/preprocessing/__pycache__/embed.cpython-311.pyc
ADDED
Binary file (3.53 kB)
rag_demo/preprocessing/__pycache__/load_to_vectordb.cpython-311.pyc
ADDED
Binary file (2.39 kB)
rag_demo/preprocessing/__pycache__/pdf_conversion.cpython-311.pyc
ADDED
Binary file (1.8 kB)
rag_demo/preprocessing/base/__init__.py
ADDED
@@ -0,0 +1,12 @@
from .document import Document, CleanedDocument
from .chunk import Chunk
from .embedded_chunk import EmbeddedChunk
from .vectordb import VectorBaseDocument

__all__ = [
    "Document",
    "CleanedDocument",
    "Chunk",
    "EmbeddedChunk",
    "VectorBaseDocument",
]
rag_demo/preprocessing/base/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (528 Bytes)
rag_demo/preprocessing/base/__pycache__/chunk.cpython-311.pyc
ADDED
Binary file (927 Bytes)
rag_demo/preprocessing/base/__pycache__/document.cpython-311.pyc
ADDED
Binary file (1.12 kB)
rag_demo/preprocessing/base/__pycache__/embedded_chunk.cpython-311.pyc
ADDED
Binary file (2.04 kB)
rag_demo/preprocessing/base/__pycache__/vectordb.cpython-311.pyc
ADDED
Binary file (16.7 kB)
rag_demo/preprocessing/base/chunk.py
ADDED
@@ -0,0 +1,13 @@
from abc import ABC

from pydantic import UUID4, Field

from rag_demo.preprocessing.base.vectordb import VectorBaseDocument


class Chunk(VectorBaseDocument, ABC):
    content: str
    document_id: UUID4
    chunk_id: UUID4
    metadata: dict = Field(default_factory=dict)
rag_demo/preprocessing/base/document.py
ADDED
@@ -0,0 +1,19 @@
from abc import ABC

from pydantic import UUID4, BaseModel

from .vectordb import VectorBaseDocument


class CleanedDocument(VectorBaseDocument, ABC):
    content: str
    doc_id: UUID4
    doc_title: str
    # doc_url: str


class Document(BaseModel):
    text: str
    document_id: UUID4
    metadata: dict
rag_demo/preprocessing/base/embedded_chunk.py
ADDED
@@ -0,0 +1,34 @@
from abc import ABC

from pydantic import UUID4, Field


from .vectordb import VectorBaseDocument


class EmbeddedChunk(VectorBaseDocument, ABC):
    content: str
    embedding: list[float] | None
    document_id: UUID4
    chunk_id: UUID4
    metadata: dict = Field(default_factory=dict)
    similarity: float | None

    @classmethod
    def to_context(cls, chunks: list["EmbeddedChunk"]) -> str:
        context = ""
        for i, chunk in enumerate(chunks):
            context += f"""
            Chunk {i + 1}:
            Type: {chunk.__class__.__name__}
            Document ID: {chunk.document_id}
            Chunk ID: {chunk.chunk_id}
            Content: {chunk.content}\n
            """

        return context

    class Config:
        name = "embedded_documents"
        category = "Document"
        use_vector_index = True
rag_demo/preprocessing/base/embeddings.py
ADDED
@@ -0,0 +1,145 @@
from functools import cached_property
from pathlib import Path
from typing import Optional, ClassVar
from threading import Lock

import numpy as np
from loguru import logger
from numpy.typing import NDArray
from sentence_transformers.SentenceTransformer import SentenceTransformer
from transformers import AutoTokenizer

from rag_demo.settings import settings


class SingletonMeta(type):
    """
    This is a thread-safe implementation of Singleton.
    """

    _instances: ClassVar = {}

    _lock: Lock = Lock()

    """
    We now have a lock object that will be used to synchronize threads during
    first access to the Singleton.
    """

    def __call__(cls, *args, **kwargs):
        """
        Possible changes to the value of the `__init__` argument do not affect
        the returned instance.
        """
        # Imagine that the program has just been launched. Since there's no
        # Singleton instance yet, multiple threads can simultaneously pass the
        # previous conditional and reach this point almost at the same time. The
        # first of them will acquire the lock and proceed, while the rest will
        # wait here.
        with cls._lock:
            # The first thread to acquire the lock reaches this conditional,
            # goes inside and creates the Singleton instance. Once it leaves the
            # lock block, a thread that might have been waiting for the lock
            # release may then enter this section. But since the Singleton field
            # is already initialized, the thread won't create a new object.
            if cls not in cls._instances:
                instance = super().__call__(*args, **kwargs)
                cls._instances[cls] = instance

        return cls._instances[cls]


class EmbeddingModelSingleton(metaclass=SingletonMeta):
    """
    A singleton class that provides a pre-trained transformer model for generating embeddings of input text.
    """

    def __init__(
        self,
        model_id: str = settings.TEXT_EMBEDDING_MODEL_ID,
        device: str = settings.RAG_MODEL_DEVICE,
        cache_dir: Optional[Path] = None,
    ) -> None:
        self._model_id = model_id
        self._device = device

        self._model = SentenceTransformer(
            self._model_id,
            device=self._device,
            cache_folder=str(cache_dir) if cache_dir else None,
        )
        self._model.eval()

    @property
    def model_id(self) -> str:
        """
        Returns the identifier of the pre-trained transformer model to use.

        Returns:
            str: The identifier of the pre-trained transformer model to use.
        """

        return self._model_id

    @cached_property
    def embedding_size(self) -> int:
        """
        Returns the size of the embeddings generated by the pre-trained transformer model.

        Returns:
            int: The size of the embeddings generated by the pre-trained transformer model.
        """

        dummy_embedding = self._model.encode("")

        return dummy_embedding.shape[0]

    @property
    def max_input_length(self) -> int:
        """
        Returns the maximum length of input text to tokenize.

        Returns:
            int: The maximum length of input text to tokenize.
        """

        return self._model.max_seq_length

    @property
    def tokenizer(self) -> AutoTokenizer:
        """
        Returns the tokenizer used to tokenize input text.

        Returns:
            AutoTokenizer: The tokenizer used to tokenize input text.
        """

        return self._model.tokenizer

    def __call__(
        self, input_text: str | list[str], to_list: bool = True
    ) -> NDArray[np.float32] | list[float] | list[list[float]]:
        """
        Generates embeddings for the input text using the pre-trained transformer model.

        Args:
            input_text (str): The input text to generate embeddings for.
            to_list (bool): Whether to return the embeddings as a list or numpy array. Defaults to True.

        Returns:
            Union[np.ndarray, list]: The embeddings generated for the input text.
        """

        try:
            embeddings = self._model.encode(input_text)
        except Exception:
            logger.error(
                f"Error generating embeddings for {self._model_id=} and {input_text=}"
            )

            return [] if to_list else np.array([])

        if to_list:
            embeddings = embeddings.tolist()

        return embeddings
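A short sketch of what the metaclass guarantees in practice (model defaults come from rag_demo.settings, so no arguments are needed here):

from rag_demo.preprocessing.base.embeddings import EmbeddingModelSingleton

embedder_a = EmbeddingModelSingleton()
embedder_b = EmbeddingModelSingleton()
assert embedder_a is embedder_b  # SingletonMeta returns the cached instance

vector = embedder_a("some text")  # list[float] with to_list=True (the default)
print(len(vector), embedder_a.embedding_size)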
rag_demo/preprocessing/base/vectordb.py
ADDED
@@ -0,0 +1,289 @@
import uuid
from abc import ABC
from typing import Any, Callable, Dict, Generic, Type, TypeVar
from uuid import UUID

import numpy as np
from loguru import logger
from pydantic import UUID4, BaseModel, Field
from qdrant_client.http import exceptions
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.models import CollectionInfo, PointStruct, Record


from rag_demo.infra.qdrant import connection

T = TypeVar("T", bound="VectorBaseDocument")

EMBEDDING_SIZE = 1024


class VectorBaseDocument(BaseModel, Generic[T], ABC):
    id: UUID4 = Field(default_factory=uuid.uuid4)

    def __eq__(self, value: object) -> bool:
        if not isinstance(value, self.__class__):
            return False

        return self.id == value.id

    def __hash__(self) -> int:
        return hash(self.id)

    @classmethod
    def from_record(cls: Type[T], point: Record) -> T:
        _id = UUID(point.id, version=4)
        payload = point.payload or {}

        attributes = {
            "id": _id,
            **payload,
        }
        if cls._has_class_attribute("embedding"):
            attributes["embedding"] = point.vector or None

        return cls(**attributes)

    def to_point(self: T, **kwargs) -> PointStruct:
        exclude_unset = kwargs.pop("exclude_unset", False)
        by_alias = kwargs.pop("by_alias", True)

        payload = self.model_dump(
            exclude_unset=exclude_unset, by_alias=by_alias, **kwargs
        )

        _id = str(payload.pop("id"))
        vector = payload.pop("embedding", {})
        if vector and isinstance(vector, np.ndarray):
            vector = vector.tolist()

        return PointStruct(id=_id, vector=vector, payload=payload)

    def model_dump(self: T, **kwargs) -> dict:
        dict_ = super().model_dump(**kwargs)

        dict_ = self._uuid_to_str(dict_)

        return dict_

    def _uuid_to_str(self, item: Any) -> Any:
        if isinstance(item, dict):
            for key, value in item.items():
                if isinstance(value, UUID):
                    item[key] = str(value)
                elif isinstance(value, list):
                    item[key] = [self._uuid_to_str(v) for v in value]
                elif isinstance(value, dict):
                    item[key] = {k: self._uuid_to_str(v) for k, v in value.items()}

        return item

    @classmethod
    def bulk_insert(cls: Type[T], documents: list["VectorBaseDocument"]) -> bool:
        try:
            cls._bulk_insert(documents)
            logger.info(
                f"Successfully inserted {len(documents)} documents into {cls.get_collection_name()}"
            )

        except Exception as e:
            logger.error(f"Error inserting documents: {e}")
            logger.info(
                f"Collection '{cls.get_collection_name()}' does not exist. Trying to create the collection and reinsert the documents."
            )

            cls.create_collection()

            try:
                cls._bulk_insert(documents)
            except Exception as e:
                logger.error(f"Error inserting documents: {e}")
                logger.error(
                    f"Failed to insert documents in '{cls.get_collection_name()}'."
                )

                return False

        return True

    @classmethod
    def _bulk_insert(cls: Type[T], documents: list["VectorBaseDocument"]) -> None:
        points = [doc.to_point() for doc in documents]

        connection.upsert(collection_name=cls.get_collection_name(), points=points)

    @classmethod
    def bulk_find(
        cls: Type[T], limit: int = 10, **kwargs
    ) -> tuple[list[T], UUID | None]:
        try:
            documents, next_offset = cls._bulk_find(limit=limit, **kwargs)
        except exceptions.UnexpectedResponse:
            logger.error(
                f"Failed to search documents in '{cls.get_collection_name()}'."
            )

            documents, next_offset = [], None

        return documents, next_offset

    @classmethod
    def _bulk_find(
        cls: Type[T], limit: int = 10, **kwargs
    ) -> tuple[list[T], UUID | None]:
        collection_name = cls.get_collection_name()

        offset = kwargs.pop("offset", None)
        offset = str(offset) if offset else None

        records, next_offset = connection.scroll(
            collection_name=collection_name,
            limit=limit,
            with_payload=kwargs.pop("with_payload", True),
            with_vectors=kwargs.pop("with_vectors", False),
            offset=offset,
            **kwargs,
        )
        documents = [cls.from_record(record) for record in records]
        if next_offset is not None:
            next_offset = UUID(next_offset, version=4)

        return documents, next_offset

    @classmethod
    def search(cls: Type[T], query_vector: list, limit: int = 10, **kwargs) -> list[T]:
        try:
            documents = cls._search(query_vector=query_vector, limit=limit, **kwargs)
        except exceptions.UnexpectedResponse:
            logger.error(
                f"Failed to search documents in '{cls.get_collection_name()}'."
            )

            documents = []

        return documents

    @classmethod
    def _search(cls: Type[T], query_vector: list, limit: int = 10, **kwargs) -> list[T]:
        collection_name = cls.get_collection_name()
        records = connection.search(
            collection_name=collection_name,
            query_vector=query_vector,
            limit=limit,
            with_payload=kwargs.pop("with_payload", True),
            with_vectors=kwargs.pop("with_vectors", False),
            **kwargs,
        )
        documents = [cls.from_record(record) for record in records]

        return documents

    @classmethod
    def get_or_create_collection(cls: Type[T]) -> CollectionInfo:
        collection_name = cls.get_collection_name()

        try:
            return connection.get_collection(collection_name=collection_name)
        except exceptions.UnexpectedResponse:
            use_vector_index = cls.get_use_vector_index()

            collection_created = cls._create_collection(
                collection_name=collection_name, use_vector_index=use_vector_index
            )
            if collection_created is False:
                raise RuntimeError(
                    f"Couldn't create collection {collection_name}"
                ) from None

            return connection.get_collection(collection_name=collection_name)

    @classmethod
    def create_collection(cls: Type[T]) -> bool:
        collection_name = cls.get_collection_name()
        use_vector_index = cls.get_use_vector_index()
        logger.info(
            f"Creating collection {collection_name} with use_vector_index={use_vector_index}"
        )
        return cls._create_collection(
            collection_name=collection_name, use_vector_index=use_vector_index
        )

    @classmethod
    def _create_collection(
        cls, collection_name: str, use_vector_index: bool = True
    ) -> bool:
        if use_vector_index is True:
            vectors_config = VectorParams(size=EMBEDDING_SIZE, distance=Distance.COSINE)
        else:
            vectors_config = {}

        return connection.create_collection(
            collection_name=collection_name, vectors_config=vectors_config
        )

    @classmethod
    def get_collection_name(cls: Type[T]) -> str:
        if not hasattr(cls, "Config") or not hasattr(cls.Config, "name"):
            raise Exception(
                f"The class {cls} should define a Config class with the 'name' property that reflects the collection's name."
            )

        return cls.Config.name

    @classmethod
    def get_use_vector_index(cls: Type[T]) -> bool:
        if not hasattr(cls, "Config") or not hasattr(cls.Config, "use_vector_index"):
            return True

        return cls.Config.use_vector_index

    @classmethod
    def group_by_class(
        cls: Type["VectorBaseDocument"], documents: list["VectorBaseDocument"]
    ) -> Dict["VectorBaseDocument", list["VectorBaseDocument"]]:
        return cls._group_by(documents, selector=lambda doc: doc.__class__)

    @classmethod
    def _group_by(
        cls: Type[T], documents: list[T], selector: Callable[[T], Any]
    ) -> Dict[Any, list[T]]:
        grouped = {}
        for doc in documents:
            key = selector(doc)

            if key not in grouped:
                grouped[key] = []
            grouped[key].append(doc)

        return grouped

    @classmethod
    def collection_name_to_class(
        cls: Type["VectorBaseDocument"], collection_name: str
    ) -> type["VectorBaseDocument"]:
        for subclass in cls.__subclasses__():
            try:
                if subclass.get_collection_name() == collection_name:
                    return subclass
            except Exception:
                pass

            try:
                return subclass.collection_name_to_class(collection_name)
            except ValueError:
                continue

        raise ValueError(f"No subclass found for collection name: {collection_name}")

    @classmethod
    def _has_class_attribute(cls: Type[T], attribute_name: str) -> bool:
        if attribute_name in cls.__annotations__:
            return True

        for base in cls.__bases__:
            if hasattr(base, "_has_class_attribute") and base._has_class_attribute(
                attribute_name
            ):
                return True

        return False
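To use this ODM-style base, a subclass only needs its fields plus a Config naming the Qdrant collection; a hypothetical sketch (ReportChunk and its collection name are illustrative, not part of the repo):

from uuid import uuid4
from pydantic import UUID4, Field
from rag_demo.preprocessing.base.vectordb import VectorBaseDocument

class ReportChunk(VectorBaseDocument):  # hypothetical subclass
    content: str
    document_id: UUID4
    embedding: list[float] | None = None
    metadata: dict = Field(default_factory=dict)

    class Config:
        name = "report_chunks"   # collection name used by get_collection_name()
        use_vector_index = True  # 1024-dim cosine vectors (EMBEDDING_SIZE)

# bulk_insert() creates the collection on the first failure, then retries:
ReportChunk.bulk_insert(
    [ReportChunk(content="hello", document_id=uuid4(), embedding=[0.0] * 1024)]
)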
rag_demo/preprocessing/chunking.py
ADDED
@@ -0,0 +1,26 @@
from uuid import uuid4

from langchain.text_splitter import MarkdownTextSplitter
from rag_demo.preprocessing.base import Chunk, Document


def chunk_text(
    document: Document, chunk_size: int = 500, chunk_overlap: int = 50
) -> list[Chunk]:
    text_splitter = MarkdownTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    chunks = text_splitter.split_text(document.text)
    result = []
    for chunk in chunks:
        result.append(
            Chunk(
                content=chunk,
                document_id=document.document_id,
                chunk_id=uuid4(),
                metadata=document.metadata,
            )
        )

    return result
rag_demo/preprocessing/embed.py
ADDED
@@ -0,0 +1,57 @@
from typing_extensions import Annotated
from typing import Generator
from .base import Chunk
from .base import EmbeddedChunk
from .chunking import chunk_text
from huggingface_hub import InferenceClient
import os
from dotenv import load_dotenv
from uuid import uuid4
from loguru import logger

load_dotenv()


def batch(list_: list, size: int) -> Generator[list, None, None]:
    yield from (list_[i : i + size] for i in range(0, len(list_), size))


def embed_chunks(chunks: list[Chunk]) -> list[EmbeddedChunk]:
    api = InferenceClient(
        model="intfloat/multilingual-e5-large-instruct",
        token=os.getenv("HF_API_TOKEN"),
    )
    logger.info(f"Embedding {len(chunks)} chunks")
    embedded_chunks = []
    for chunk in chunks:
        try:
            embedded_chunks.append(
                EmbeddedChunk(
                    id=uuid4(),
                    content=chunk.content,
                    # feature_extraction returns an ndarray; convert it to a
                    # plain list for the pydantic list[float] field
                    embedding=api.feature_extraction(chunk.content).tolist(),
                    document_id=chunk.document_id,
                    chunk_id=chunk.chunk_id,
                    metadata=chunk.metadata,
                    similarity=None,
                )
            )
        except Exception as e:
            logger.error(f"Error embedding chunk: {e}")
    logger.info(f"{len(embedded_chunks)} chunks embedded successfully")

    return embedded_chunks


def chunk_and_embed(
    cleaned_documents: Annotated[list, "cleaned_documents"],
) -> Annotated[list, "embedded_documents"]:
    embedded_chunks = []
    for document in cleaned_documents:
        chunks = chunk_text(document)

        for batched_chunks in batch(chunks, 10):
            batched_embedded_chunks = embed_chunks(batched_chunks)
            embedded_chunks.extend(batched_embedded_chunks)
    logger.info(f"{len(embedded_chunks)} chunks embedded successfully")
    return embedded_chunks
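The batch helper drives the 10-chunk embedding batches above; a quick illustration of its behavior:

from rag_demo.preprocessing.embed import batch

print(list(batch(list(range(7)), size=3)))
# -> [[0, 1, 2], [3, 4, 5], [6]]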
rag_demo/preprocessing/load_to_vectordb.py
ADDED
@@ -0,0 +1,30 @@
from loguru import logger
from typing_extensions import Annotated
from typing import Generator

from .base import VectorBaseDocument


def batch(list_: list, size: int) -> Generator[list, None, None]:
    yield from (list_[i : i + size] for i in range(0, len(list_), size))


def load_to_vector_db(
    documents: Annotated[list, "documents"],
) -> Annotated[bool, "successful"]:
    logger.info(f"Loading {len(documents)} documents into the vector database.")

    grouped_documents = VectorBaseDocument.group_by_class(documents)
    for document_class, documents in grouped_documents.items():
        logger.info(f"Loading documents into {document_class.get_collection_name()}")
        for documents_batch in batch(documents, size=4):
            try:
                document_class.bulk_insert(documents_batch)
            except Exception as e:
                logger.error(
                    f"Failed to insert documents into {document_class.get_collection_name()}: {e}"
                )

                return False

    return True
rag_demo/preprocessing/pdf_conversion.py
ADDED
@@ -0,0 +1,33 @@
from llama_parse import LlamaParse
from llama_index.core import SimpleDirectoryReader
from uuid import uuid4
from .base import Document
from loguru import logger
import os

from dotenv import load_dotenv

load_dotenv()


# Set up the parser. The API key is read from the environment rather than
# hard-coded, so no credential is committed to the repo.
parser = LlamaParse(
    api_key=os.getenv("LLAMA_CLOUD_API_KEY"),
    result_type="markdown",  # "markdown" and "text" are available
)


def convert_pdf_to_text(filepaths: list[str]) -> Document:
    file_extractor = {".pdf": parser}
    # Use SimpleDirectoryReader to parse our files
    documents = SimpleDirectoryReader(
        input_files=filepaths, file_extractor=file_extractor
    ).load_data()

    logger.info(f"Converted {len(documents)} documents")

    return Document(
        document_id=uuid4(),
        text=" ".join(document.text for document in documents),
        metadata={"filename": filepaths[0].split("/")[-1]},
    )
rag_demo/rag/__pycache__/prompt_templates.cpython-311.pyc
ADDED
Binary file (2.75 kB)
rag_demo/rag/__pycache__/query_expansion.cpython-311.pyc
ADDED
Binary file (2.4 kB)
rag_demo/rag/__pycache__/reranker.cpython-311.pyc
ADDED
Binary file (1.96 kB)
rag_demo/rag/__pycache__/retriever.cpython-311.pyc
ADDED
Binary file (8.21 kB)
rag_demo/rag/base/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .template_factory import PromptTemplateFactory

__all__ = ["PromptTemplateFactory"]
rag_demo/rag/base/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (283 Bytes)
rag_demo/rag/base/__pycache__/query.cpython-311.pyc
ADDED
Binary file (2.08 kB)
rag_demo/rag/base/__pycache__/template_factory.cpython-311.pyc
ADDED
Binary file (1.64 kB)
rag_demo/rag/base/base.py
ADDED
@@ -0,0 +1,22 @@
from abc import ABC, abstractmethod
from typing import Any

from langchain.prompts import PromptTemplate
from pydantic import BaseModel

from rag_demo.rag.base.query import Query


class PromptTemplateFactory(ABC, BaseModel):
    @abstractmethod
    def create_template(self) -> PromptTemplate:
        pass


class RAGStep(ABC):
    def __init__(self, mock: bool = False) -> None:
        self._mock = mock

    @abstractmethod
    def generate(self, query: Query, *args, **kwargs) -> Any:
        pass
rag_demo/rag/base/query.py
ADDED
@@ -0,0 +1,29 @@
from pydantic import Field

from rag_demo.preprocessing.base import VectorBaseDocument


class Query(VectorBaseDocument):
    content: str
    metadata: dict = Field(default_factory=dict)

    class Config:
        category = "query"

    @classmethod
    def from_str(cls, query: str) -> "Query":
        return Query(content=query.strip("\n "))

    def replace_content(self, new_content: str) -> "Query":
        return Query(
            id=self.id,
            content=new_content,
            metadata=self.metadata,
        )


class EmbeddedQuery(Query):
    embedding: list[float]

    class Config:
        category = "query"
rag_demo/rag/base/template_factory.py
ADDED
@@ -0,0 +1,22 @@
from abc import ABC, abstractmethod
from typing import Any

from langchain.prompts import PromptTemplate
from pydantic import BaseModel

from .query import Query


class PromptTemplateFactory(ABC, BaseModel):
    @abstractmethod
    def create_template(self) -> PromptTemplate:
        pass


class RAGStep(ABC):
    def __init__(self, mock: bool = False) -> None:
        self._mock = mock

    @abstractmethod
    def generate(self, query: Query, *args, **kwargs) -> Any:
        pass
rag_demo/rag/prompt_templates.py
ADDED
@@ -0,0 +1,38 @@
from langchain.prompts import PromptTemplate

from .base import PromptTemplateFactory


class QueryExpansionTemplate(PromptTemplateFactory):
    prompt: str = """You are an AI language model assistant. Your task is to generate {expand_to_n}
different versions of the given user question to retrieve relevant documents from a vector
database. By generating multiple perspectives on the user question, your goal is to help
the user overcome some of the limitations of the distance-based similarity search.
Provide these alternative questions separated by '{separator}'.
Original question: {question}"""

    @property
    def separator(self) -> str:
        return "#next-question#"

    def create_template(self, expand_to_n: int) -> PromptTemplate:
        return PromptTemplate(
            template=self.prompt,
            input_variables=["question"],
            partial_variables={
                "separator": self.separator,
                "expand_to_n": expand_to_n,
            },
        )


class AnswerGenerationTemplate(PromptTemplateFactory):
    prompt: str = """You are an AI language model assistant. Your task is to generate an answer to the given user question based on the provided context.
Context: {context}
Question: {question}

Give your answer in markdown format if needed, for example if a table is the best way to answer the question, or if titles and subheadings are needed.
Give only your answer, do not include any other text like 'Certainly! Here is the answer:' or 'The answer is:' or anything similar."""

    def create_template(self, context: str, question: str) -> str:
        return self.prompt.format(context=context, question=question)
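A quick sketch of how the expansion template renders (the question is illustrative):

from rag_demo.rag.prompt_templates import QueryExpansionTemplate

template = QueryExpansionTemplate()
prompt = template.create_template(expand_to_n=2)
print(prompt.format(question="What is the warranty period?"))
# The model's reply is later split on template.separator ("#next-question#")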
rag_demo/rag/query_expansion.py
ADDED
@@ -0,0 +1,39 @@
import os
from typing import Any

from huggingface_hub import InferenceClient

from rag_demo.rag.base.query import Query
from rag_demo.rag.base.template_factory import RAGStep
from rag_demo.rag.prompt_templates import QueryExpansionTemplate


class QueryExpansion(RAGStep):
    def generate(self, query: Query, expand_to_n: int) -> Any:
        api = InferenceClient(
            model="Qwen/Qwen2.5-72B-Instruct",
            token=os.getenv("HF_API_TOKEN"),
        )
        query_expansion_template = QueryExpansionTemplate()
        # The original query counts as one, so ask for expand_to_n - 1 rewrites
        prompt = query_expansion_template.create_template(expand_to_n - 1)
        response = api.chat_completion(
            [
                {
                    "role": "user",
                    # PromptTemplate.format fills in the separator and
                    # expand_to_n partials configured above
                    "content": prompt.format(question=query.content),
                }
            ]
        )
        result = response.choices[0].message.content
        queries_content = result.split(query_expansion_template.separator)
        queries = [query]
        queries += [
            query.replace_content(stripped_content)
            for content in queries_content
            if (stripped_content := content.strip())
        ]
        return queries
rag_demo/rag/reranker.py
ADDED
@@ -0,0 +1,24 @@
import os

from huggingface_hub import InferenceClient

from rag_demo.rag.base.query import Query
from rag_demo.rag.base.template_factory import RAGStep
from rag_demo.preprocessing.embed import EmbeddedChunk


class Reranker(RAGStep):
    def generate(
        self, query: Query, chunks: list[EmbeddedChunk], keep_top_k: int
    ) -> list[EmbeddedChunk]:
        api = InferenceClient(
            model="intfloat/multilingual-e5-large-instruct",
            token=os.getenv("HF_API_TOKEN"),
        )
        similarity = api.sentence_similarity(
            query.content, [chunk.content for chunk in chunks]
        )
        for chunk, sim in zip(chunks, similarity):
            chunk.similarity = sim

        return sorted(chunks, key=lambda x: x.similarity, reverse=True)[:keep_top_k]
rag_demo/rag/retriever.py
ADDED
@@ -0,0 +1,133 @@
import concurrent.futures
import os

from loguru import logger
from huggingface_hub import InferenceClient

from rag_demo.preprocessing.base import (
    EmbeddedChunk,
)
from rag_demo.rag.base.query import EmbeddedQuery, Query

from .query_expansion import QueryExpansion
from .reranker import Reranker
from .prompt_templates import AnswerGenerationTemplate

from dotenv import load_dotenv

load_dotenv()


def flatten(nested_list: list) -> list:
    """Flatten a list of lists into a single list."""

    return [item for sublist in nested_list for item in sublist]


class RAGPipeline:
    def __init__(self, mock: bool = False) -> None:
        self._query_expander = QueryExpansion(mock=mock)
        self._reranker = Reranker(mock=mock)

    def search(
        self,
        query: str,
        k: int = 3,
        expand_to_n_queries: int = 3,
    ) -> list:
        query_model = Query.from_str(query)

        n_generated_queries = self._query_expander.generate(
            query_model, expand_to_n=expand_to_n_queries
        )
        logger.info(
            f"Successfully generated {len(n_generated_queries)} search queries.",
        )

        with concurrent.futures.ThreadPoolExecutor() as executor:
            search_tasks = [
                executor.submit(self._search, _query_model, k)
                for _query_model in n_generated_queries
            ]

            n_k_documents = [
                task.result() for task in concurrent.futures.as_completed(search_tasks)
            ]
            n_k_documents = flatten(n_k_documents)
            n_k_documents = list(set(n_k_documents))

        logger.info(f"{len(n_k_documents)} documents retrieved successfully")

        if len(n_k_documents) > 0:
            k_documents = self.rerank(query, chunks=n_k_documents, keep_top_k=k)
        else:
            k_documents = []

        return k_documents

    def _search(self, query: Query, k: int = 3) -> list[EmbeddedChunk]:
        assert k >= 3, "k should be >= 3"

        def _search_data(
            data_category_odm: type[EmbeddedChunk], embedded_query: EmbeddedQuery
        ) -> list[EmbeddedChunk]:
            return data_category_odm.search(
                query_vector=embedded_query.embedding,
                limit=k,
            )

        api = InferenceClient(
            model="intfloat/multilingual-e5-large-instruct",
            token=os.getenv("HF_API_TOKEN"),
        )
        embedded_query: EmbeddedQuery = EmbeddedQuery(
            embedding=api.feature_extraction(query.content),
            id=query.id,
            content=query.content,
        )

        retrieved_chunks = _search_data(EmbeddedChunk, embedded_query)
        logger.info(f"{len(retrieved_chunks)} documents retrieved successfully")

        return retrieved_chunks

    def rerank(
        self, query: str | Query, chunks: list[EmbeddedChunk], keep_top_k: int
    ) -> list[EmbeddedChunk]:
        if isinstance(query, str):
            query = Query.from_str(query)

        reranked_documents = self._reranker.generate(
            query=query, chunks=chunks, keep_top_k=keep_top_k
        )

        logger.info(f"{len(reranked_documents)} documents reranked successfully.")

        return reranked_documents

    def generate_answer(self, query: str, reranked_chunks: list[EmbeddedChunk]) -> str:
        context = ""
        for chunk in reranked_chunks:
            context += "\n Document: "
            context += chunk.content
        api = InferenceClient(
            model="meta-llama/Llama-3.1-8B-Instruct",
            token=os.getenv("HF_API_TOKEN"),
        )
        answer_generation_template = AnswerGenerationTemplate()
        prompt = answer_generation_template.create_template(context, query)
        logger.info(prompt)
        response = api.chat_completion(
            [{"role": "user", "content": prompt}],
            max_tokens=8192,
        )
        return response.choices[0].message.content

    def rag(self, query: str) -> tuple[str, list[str]]:
        docs = self.search(query, k=10)
        reranked_docs = self.rerank(query, docs, keep_top_k=10)
        return (
            self.generate_answer(query, reranked_docs),
            [doc.metadata["filename"].split(".pdf")[0] for doc in reranked_docs],
        )
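Taken together, a minimal end-to-end sketch of the pipeline (assuming documents were already ingested and HF_API_TOKEN is set in the environment):

from rag_demo.rag.retriever import RAGPipeline

pipeline = RAGPipeline()
answer, source_files = pipeline.rag("What does the uploaded document cover?")
print(answer)
print(source_files)  # source PDF names with the ".pdf" suffix stripped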