AdrienB134 committed
Commit
7fdb8e9
1 Parent(s): 9ae6b81

Upload 54 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. Dockerfile +32 -0
  3. README.md +0 -8
  4. app.py +81 -0
  5. pyproject.toml +18 -0
  6. rag_demo/__init__.py +3 -0
  7. rag_demo/__pycache__/__init__.cpython-311.pyc +0 -0
  8. rag_demo/__pycache__/pipeline.cpython-311.pyc +0 -0
  9. rag_demo/__pycache__/settings.cpython-311.pyc +0 -0
  10. rag_demo/data/test.pdf +0 -0
  11. rag_demo/data/test2.pdf +3 -0
  12. rag_demo/infra/__pycache__/qdrant.cpython-311.pyc +0 -0
  13. rag_demo/infra/qdrant.py +25 -0
  14. rag_demo/pipeline.py +13 -0
  15. rag_demo/preprocessing/__init__.py +5 -0
  16. rag_demo/preprocessing/__pycache__/__init__.cpython-311.pyc +0 -0
  17. rag_demo/preprocessing/__pycache__/chunking.cpython-311.pyc +0 -0
  18. rag_demo/preprocessing/__pycache__/embed.cpython-311.pyc +0 -0
  19. rag_demo/preprocessing/__pycache__/load_to_vectordb.cpython-311.pyc +0 -0
  20. rag_demo/preprocessing/__pycache__/pdf_conversion.cpython-311.pyc +0 -0
  21. rag_demo/preprocessing/base/__init__.py +12 -0
  22. rag_demo/preprocessing/base/__pycache__/__init__.cpython-311.pyc +0 -0
  23. rag_demo/preprocessing/base/__pycache__/chunk.cpython-311.pyc +0 -0
  24. rag_demo/preprocessing/base/__pycache__/document.cpython-311.pyc +0 -0
  25. rag_demo/preprocessing/base/__pycache__/embedded_chunk.cpython-311.pyc +0 -0
  26. rag_demo/preprocessing/base/__pycache__/vectordb.cpython-311.pyc +0 -0
  27. rag_demo/preprocessing/base/chunk.py +13 -0
  28. rag_demo/preprocessing/base/document.py +19 -0
  29. rag_demo/preprocessing/base/embedded_chunk.py +34 -0
  30. rag_demo/preprocessing/base/embeddings.py +145 -0
  31. rag_demo/preprocessing/base/vectordb.py +289 -0
  32. rag_demo/preprocessing/chunking.py +26 -0
  33. rag_demo/preprocessing/embed.py +57 -0
  34. rag_demo/preprocessing/load_to_vectordb.py +30 -0
  35. rag_demo/preprocessing/pdf_conversion.py +33 -0
  36. rag_demo/rag/__pycache__/prompt_templates.cpython-311.pyc +0 -0
  37. rag_demo/rag/__pycache__/query_expansion.cpython-311.pyc +0 -0
  38. rag_demo/rag/__pycache__/reranker.cpython-311.pyc +0 -0
  39. rag_demo/rag/__pycache__/retriever.cpython-311.pyc +0 -0
  40. rag_demo/rag/base/__init__.py +3 -0
  41. rag_demo/rag/base/__pycache__/__init__.cpython-311.pyc +0 -0
  42. rag_demo/rag/base/__pycache__/query.cpython-311.pyc +0 -0
  43. rag_demo/rag/base/__pycache__/template_factory.cpython-311.pyc +0 -0
  44. rag_demo/rag/base/base.py +22 -0
  45. rag_demo/rag/base/query.py +29 -0
  46. rag_demo/rag/base/template_factory.py +22 -0
  47. rag_demo/rag/prompt_templates.py +38 -0
  48. rag_demo/rag/query_expansion.py +39 -0
  49. rag_demo/rag/reranker.py +24 -0
  50. rag_demo/rag/retriever.py +133 -0
.gitattributes ADDED
@@ -0,0 +1 @@
+ rag_demo/data/test2.pdf filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,33 @@
+ # Use Python 3.11 as base image (pyproject.toml requires Python >= 3.11)
+ FROM python:3.11-slim
+
+ # Set working directory
+ WORKDIR /app
+
+ # Install system dependencies and uv
+ RUN apt-get update && apt-get install -y \
+     poppler-utils \
+     curl \
+     && rm -rf /var/lib/apt/lists/* \
+     && curl -LsSf https://astral.sh/uv/install.sh | sh
+
+ # Make uv available on PATH in subsequent layers
+ ENV PATH="/root/.local/bin:$PATH"
+
+ # Copy requirements first to leverage Docker cache
+ COPY pyproject.toml .
+
+ # Install Python dependencies using uv
+ RUN uv pip install -r pyproject.toml --system
+
+ # Copy the rest of the application
+ COPY . .
+
+ # Create directories for uploads and embeddings if they don't exist
+ RUN mkdir -p uploads embeddings
+
+ # Expose the port the app runs on
+ EXPOSE 7860
+
+ # Run the app
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,8 +0,0 @@
- ---
- title: matriv-rag-demo
- colorFrom: blue
- colorTo: red
- sdk: docker
- app_file: app.py
- pinned: false
- ---
app.py ADDED
@@ -0,0 +1,80 @@
+ from fastapi import FastAPI, File, UploadFile, Request
+ from fastapi.templating import Jinja2Templates
+ from fastapi.responses import HTMLResponse
+ from fastapi.staticfiles import StaticFiles
+ from pydantic import BaseModel
+ import os
+ from rag_demo.pipeline import process_pdf
+ import nest_asyncio
+ from rag_demo.rag.retriever import RAGPipeline
+ from loguru import logger
+
+ app = FastAPI()
+
+ # Apply nest_asyncio at the start of the application
+ nest_asyncio.apply()
+
+ # Configure the Jinja2 templates directory
+ templates = Jinja2Templates(directory="templates")
+
+ app.mount("/static", StaticFiles(directory="static"), name="static")
+
+
+ class ChatRequest(BaseModel):
+     question: str
+
+
+ @app.get("/", response_class=HTMLResponse)
+ async def upload_page(request: Request):
+     return templates.TemplateResponse("upload.html", {"request": request})
+
+
+ @app.get("/chat", response_class=HTMLResponse)
+ async def chat_page(request: Request):
+     return templates.TemplateResponse("chat.html", {"request": request})
+
+
+ @app.post("/upload")
+ async def upload_pdf(request: Request, file: UploadFile = File(...)):
+     try:
+         # Create the data directory if it doesn't exist
+         os.makedirs("data", exist_ok=True)
+
+         file_path = f"data/{file.filename}"
+         with open(file_path, "wb") as buffer:
+             content = await file.read()
+             buffer.write(content)
+
+         # Process the PDF file
+         await process_pdf(file_path)
+
+         # Return template response with success message
+         return templates.TemplateResponse(
+             "upload.html",
+             {
+                 "request": request,
+                 "message": f"Successfully processed {file.filename}",
+                 "processing": False,
+             },
+         )
+     except Exception as e:
+         return templates.TemplateResponse(
+             "upload.html", {"request": request, "error": str(e), "processing": False}
+         )
+
+
+ @app.post("/chat")
+ async def chat(chat_request: ChatRequest):
+     rag_pipeline = RAGPipeline()
+     try:
+         answer = rag_pipeline.rag(chat_request.question)
+         logger.info(answer)
+         return {"answer": answer}
+     except Exception as e:
+         return {"error": str(e)}
+
+
+ if __name__ == "__main__":
+     import uvicorn
+
+     uvicorn.run(app, host="0.0.0.0", port=7860)
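The two endpoints can be exercised from a script once the server is up. A minimal sketch, assuming the app is reachable on localhost:7860 and that the `requests` package is available (it is not declared in pyproject.toml):

```python
import requests

BASE_URL = "http://localhost:7860"

# Upload a PDF for ingestion; the multipart field must be named "file"
with open("rag_demo/data/test.pdf", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/upload",
        files={"file": ("test.pdf", f, "application/pdf")},
    )
print(resp.status_code)

# Ask a question about the ingested documents
resp = requests.post(f"{BASE_URL}/chat", json={"question": "What is this document about?"})
print(resp.json())
```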
pyproject.toml ADDED
@@ -0,0 +1,23 @@
+ [project]
+ name = "rag-base"
+ version = "0.1.0"
+ description = "Add your description here"
+ readme = "README.md"
+ requires-python = ">=3.11"
+ dependencies = [
+     "loguru>=0.7.2",
+     "langchain>=0.3.9",
+     "marker-pdf>=1.0.2",
+     "qdrant-client[fastembed]>=1.12.1",
+     "fastapi>=0.115.6",
+     "pydantic>=2.10.3",
+     "python-multipart>=0.0.19",
+     "uvicorn>=0.32.1",
+     "huggingface-hub>=0.26.3",
+     "llama-parse>=0.5.17",
+     # Imported by the code but previously missing from this list
+     "nest-asyncio>=1.6.0",
+     "python-dotenv>=1.0.0",
+     "jinja2>=3.1.0",
+     "sentence-transformers>=3.0.0",
+ ]
rag_demo/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .infra.qdrant import connection
+
+ __all__ = ["connection"]
rag_demo/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (263 Bytes).

rag_demo/__pycache__/pipeline.cpython-311.pyc ADDED
Binary file (738 Bytes).

rag_demo/__pycache__/settings.cpython-311.pyc ADDED
Binary file (1.99 kB).

rag_demo/data/test.pdf ADDED
Binary file (344 kB).
 
rag_demo/data/test2.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b3041eb7dd274b02a2f18049891dc3f184dff4151796f225b92cd34d676ba923
+ size 1962780
rag_demo/infra/__pycache__/qdrant.cpython-311.pyc ADDED
Binary file (1.32 kB).
 
rag_demo/infra/qdrant.py ADDED
@@ -0,0 +1,22 @@
+ from loguru import logger
+ from qdrant_client import QdrantClient
+
+
+ class QdrantDatabaseConnector:
+     _instance: QdrantClient | None = None
+
+     def __new__(cls, *args, **kwargs) -> QdrantClient:
+         if cls._instance is None:
+             try:
+                 cls._instance = QdrantClient(":memory:")
+
+                 logger.info("Connection to in-memory Qdrant DB successful")
+             except Exception:
+                 logger.exception("Couldn't connect to Qdrant.")
+
+                 raise
+
+         return cls._instance
+
+
+ connection = QdrantDatabaseConnector()
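Because `__new__` caches the first client it builds, every later construction returns the same in-memory `QdrantClient`. A quick sketch of that behavior:

```python
from rag_demo.infra.qdrant import QdrantDatabaseConnector, connection

a = QdrantDatabaseConnector()
b = QdrantDatabaseConnector()

# All names refer to the same underlying QdrantClient instance
assert a is b is connection
```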
rag_demo/pipeline.py ADDED
@@ -0,0 +1,13 @@
+ from rag_demo.preprocessing import (
+     convert_pdf_to_text,
+     load_to_vector_db,
+     chunk_and_embed,
+ )
+
+
+ # async so the FastAPI upload handler can await it
+ async def process_pdf(file_path: str) -> bool:
+     document = convert_pdf_to_text([file_path])
+     embedded_chunks = chunk_and_embed([document])
+     load_to_vector_db(embedded_chunks)
+     return True
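The same ingestion path can be driven outside the web app. A sketch, assuming the PDF exists and the LlamaParse and Hugging Face API tokens are set in the environment:

```python
import asyncio

from rag_demo.pipeline import process_pdf

# Conversion -> chunking -> embedding -> vector-DB load for one file
asyncio.run(process_pdf("rag_demo/data/test.pdf"))
```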
rag_demo/preprocessing/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .pdf_conversion import convert_pdf_to_text
+ from .load_to_vectordb import load_to_vector_db
+ from .embed import chunk_and_embed
+
+ __all__ = ["convert_pdf_to_text", "load_to_vector_db", "chunk_and_embed"]
rag_demo/preprocessing/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (441 Bytes).

rag_demo/preprocessing/__pycache__/chunking.cpython-311.pyc ADDED
Binary file (1.25 kB).

rag_demo/preprocessing/__pycache__/embed.cpython-311.pyc ADDED
Binary file (3.53 kB).

rag_demo/preprocessing/__pycache__/load_to_vectordb.cpython-311.pyc ADDED
Binary file (2.39 kB).

rag_demo/preprocessing/__pycache__/pdf_conversion.cpython-311.pyc ADDED
Binary file (1.8 kB).
 
rag_demo/preprocessing/base/__init__.py ADDED
@@ -0,0 +1,12 @@
+ from .document import Document, CleanedDocument
+ from .chunk import Chunk
+ from .embedded_chunk import EmbeddedChunk
+ from .vectordb import VectorBaseDocument
+
+ __all__ = [
+     "Document",
+     "CleanedDocument",
+     "Chunk",
+     "EmbeddedChunk",
+     "VectorBaseDocument",
+ ]
rag_demo/preprocessing/base/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (528 Bytes).

rag_demo/preprocessing/base/__pycache__/chunk.cpython-311.pyc ADDED
Binary file (927 Bytes).

rag_demo/preprocessing/base/__pycache__/document.cpython-311.pyc ADDED
Binary file (1.12 kB).

rag_demo/preprocessing/base/__pycache__/embedded_chunk.cpython-311.pyc ADDED
Binary file (2.04 kB).

rag_demo/preprocessing/base/__pycache__/vectordb.cpython-311.pyc ADDED
Binary file (16.7 kB).
 
rag_demo/preprocessing/base/chunk.py ADDED
@@ -0,0 +1,12 @@
+ from abc import ABC
+
+ from pydantic import UUID4, Field
+
+ from rag_demo.preprocessing.base.vectordb import VectorBaseDocument
+
+
+ class Chunk(VectorBaseDocument, ABC):
+     content: str
+     document_id: UUID4
+     chunk_id: UUID4
+     metadata: dict = Field(default_factory=dict)
rag_demo/preprocessing/base/document.py ADDED
@@ -0,0 +1,18 @@
+ from abc import ABC
+
+ from pydantic import UUID4, BaseModel
+
+ from .vectordb import VectorBaseDocument
+
+
+ class CleanedDocument(VectorBaseDocument, ABC):
+     content: str
+     doc_id: UUID4
+     doc_title: str
+     # doc_url: str
+
+
+ class Document(BaseModel):
+     text: str
+     document_id: UUID4
+     metadata: dict
rag_demo/preprocessing/base/embedded_chunk.py ADDED
@@ -0,0 +1,34 @@
+ from abc import ABC
+
+ from pydantic import UUID4, Field
+
+
+ from .vectordb import VectorBaseDocument
+
+
+ class EmbeddedChunk(VectorBaseDocument, ABC):
+     content: str
+     embedding: list[float] | None
+     document_id: UUID4
+     chunk_id: UUID4
+     metadata: dict = Field(default_factory=dict)
+     similarity: float | None
+
+     @classmethod
+     def to_context(cls, chunks: list["EmbeddedChunk"]) -> str:
+         context = ""
+         for i, chunk in enumerate(chunks):
+             context += f"""
+             Chunk {i + 1}:
+             Type: {chunk.__class__.__name__}
+             Document ID: {chunk.document_id}
+             Chunk ID: {chunk.chunk_id}
+             Content: {chunk.content}\n
+             """
+
+         return context
+
+     class Config:
+         name = "embedded_documents"
+         category = "Document"
+         use_vector_index = True
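`to_context` is what turns retrieved chunks into the context block spliced into the answer prompt. A small sketch with hypothetical contents:

```python
from uuid import uuid4

from rag_demo.preprocessing.base import EmbeddedChunk

doc_id = uuid4()
chunks = [
    EmbeddedChunk(
        content=text,
        embedding=None,
        document_id=doc_id,
        chunk_id=uuid4(),
        similarity=None,
    )
    for text in ["First passage.", "Second passage."]
]

# Prints a "Chunk 1: ... Chunk 2: ..." block ready for prompting
print(EmbeddedChunk.to_context(chunks))
```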
rag_demo/preprocessing/base/embeddings.py ADDED
@@ -0,0 +1,145 @@
+ from functools import cached_property
+ from pathlib import Path
+ from typing import Optional, ClassVar
+ from threading import Lock
+
+ import numpy as np
+ from loguru import logger
+ from numpy.typing import NDArray
+ from sentence_transformers.SentenceTransformer import SentenceTransformer
+ from transformers import AutoTokenizer
+
+ from rag_demo.settings import settings
+
+
+ class SingletonMeta(type):
+     """
+     This is a thread-safe implementation of Singleton.
+     """
+
+     _instances: ClassVar = {}
+
+     _lock: Lock = Lock()
+
+     """
+     We now have a lock object that will be used to synchronize threads during
+     first access to the Singleton.
+     """
+
+     def __call__(cls, *args, **kwargs):
+         """
+         Possible changes to the value of the `__init__` argument do not affect
+         the returned instance.
+         """
+         # Now, imagine that the program has just been launched. Since there's no
+         # Singleton instance yet, multiple threads can simultaneously pass the
+         # previous conditional and reach this point almost at the same time. The
+         # first of them will acquire lock and will proceed further, while the
+         # rest will wait here.
+         with cls._lock:
+             # The first thread to acquire the lock, reaches this conditional,
+             # goes inside and creates the Singleton instance. Once it leaves the
+             # lock block, a thread that might have been waiting for the lock
+             # release may then enter this section. But since the Singleton field
+             # is already initialized, the thread won't create a new object.
+             if cls not in cls._instances:
+                 instance = super().__call__(*args, **kwargs)
+                 cls._instances[cls] = instance
+
+         return cls._instances[cls]
+
+
+ class EmbeddingModelSingleton(metaclass=SingletonMeta):
+     """
+     A singleton class that provides a pre-trained transformer model for generating embeddings of input text.
+     """
+
+     def __init__(
+         self,
+         model_id: str = settings.TEXT_EMBEDDING_MODEL_ID,
+         device: str = settings.RAG_MODEL_DEVICE,
+         cache_dir: Optional[Path] = None,
+     ) -> None:
+         self._model_id = model_id
+         self._device = device
+
+         self._model = SentenceTransformer(
+             self._model_id,
+             device=self._device,
+             cache_folder=str(cache_dir) if cache_dir else None,
+         )
+         self._model.eval()
+
+     @property
+     def model_id(self) -> str:
+         """
+         Returns the identifier of the pre-trained transformer model to use.
+
+         Returns:
+             str: The identifier of the pre-trained transformer model to use.
+         """
+
+         return self._model_id
+
+     @cached_property
+     def embedding_size(self) -> int:
+         """
+         Returns the size of the embeddings generated by the pre-trained transformer model.
+
+         Returns:
+             int: The size of the embeddings generated by the pre-trained transformer model.
+         """
+
+         dummy_embedding = self._model.encode("")
+
+         return dummy_embedding.shape[0]
+
+     @property
+     def max_input_length(self) -> int:
+         """
+         Returns the maximum length of input text to tokenize.
+
+         Returns:
+             int: The maximum length of input text to tokenize.
+         """
+
+         return self._model.max_seq_length
+
+     @property
+     def tokenizer(self) -> AutoTokenizer:
+         """
+         Returns the tokenizer used to tokenize input text.
+
+         Returns:
+             AutoTokenizer: The tokenizer used to tokenize input text.
+         """
+
+         return self._model.tokenizer
+
+     def __call__(
+         self, input_text: str | list[str], to_list: bool = True
+     ) -> NDArray[np.float32] | list[float] | list[list[float]]:
+         """
+         Generates embeddings for the input text using the pre-trained transformer model.
+
+         Args:
+             input_text (str): The input text to generate embeddings for.
+             to_list (bool): Whether to return the embeddings as a list or numpy array. Defaults to True.
+
+         Returns:
+             Union[np.ndarray, list]: The embeddings generated for the input text.
+         """
+
+         try:
+             embeddings = self._model.encode(input_text)
+         except Exception:
+             logger.error(
+                 f"Error generating embeddings for {self._model_id=} and {input_text=}"
+             )
+
+             return [] if to_list else np.array([])
+
+         if to_list:
+             embeddings = embeddings.tolist()
+
+         return embeddings
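A minimal usage sketch, assuming `settings.TEXT_EMBEDDING_MODEL_ID` names a sentence-transformers model that can be loaded locally:

```python
from rag_demo.preprocessing.base.embeddings import EmbeddingModelSingleton

model = EmbeddingModelSingleton()

# The metaclass guarantees repeated constructions share one loaded model
assert model is EmbeddingModelSingleton()

vector = model("What does the contract say about termination?")
print(len(vector) == model.embedding_size)  # True: one embedding of model width
```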
rag_demo/preprocessing/base/vectordb.py ADDED
@@ -0,0 +1,289 @@
+ import uuid
+ from abc import ABC
+ from typing import Any, Callable, Dict, Generic, Type, TypeVar
+ from uuid import UUID
+
+ import numpy as np
+ from loguru import logger
+ from pydantic import UUID4, BaseModel, Field
+ from qdrant_client.http import exceptions
+ from qdrant_client.http.models import Distance, VectorParams
+ from qdrant_client.models import CollectionInfo, PointStruct, Record
+
+
+ from rag_demo.infra.qdrant import connection
+
+ T = TypeVar("T", bound="VectorBaseDocument")
+
+ EMBEDDING_SIZE = 1024
+
+
+ class VectorBaseDocument(BaseModel, Generic[T], ABC):
+     id: UUID4 = Field(default_factory=uuid.uuid4)
+
+     def __eq__(self, value: object) -> bool:
+         if not isinstance(value, self.__class__):
+             return False
+
+         return self.id == value.id
+
+     def __hash__(self) -> int:
+         return hash(self.id)
+
+     @classmethod
+     def from_record(cls: Type[T], point: Record) -> T:
+         _id = UUID(point.id, version=4)
+         payload = point.payload or {}
+
+         attributes = {
+             "id": _id,
+             **payload,
+         }
+         if cls._has_class_attribute("embedding"):
+             attributes["embedding"] = point.vector or None
+
+         return cls(**attributes)
+
+     def to_point(self: T, **kwargs) -> PointStruct:
+         exclude_unset = kwargs.pop("exclude_unset", False)
+         by_alias = kwargs.pop("by_alias", True)
+
+         payload = self.model_dump(
+             exclude_unset=exclude_unset, by_alias=by_alias, **kwargs
+         )
+
+         _id = str(payload.pop("id"))
+         vector = payload.pop("embedding", {})
+         if vector and isinstance(vector, np.ndarray):
+             vector = vector.tolist()
+
+         return PointStruct(id=_id, vector=vector, payload=payload)
+
+     def model_dump(self: T, **kwargs) -> dict:
+         dict_ = super().model_dump(**kwargs)
+
+         dict_ = self._uuid_to_str(dict_)
+
+         return dict_
+
+     def _uuid_to_str(self, item: Any) -> Any:
+         if isinstance(item, dict):
+             for key, value in item.items():
+                 if isinstance(value, UUID):
+                     item[key] = str(value)
+                 elif isinstance(value, list):
+                     item[key] = [self._uuid_to_str(v) for v in value]
+                 elif isinstance(value, dict):
+                     item[key] = {k: self._uuid_to_str(v) for k, v in value.items()}
+
+         return item
+
+     @classmethod
+     def bulk_insert(cls: Type[T], documents: list["VectorBaseDocument"]) -> bool:
+         try:
+             cls._bulk_insert(documents)
+             logger.info(
+                 f"Successfully inserted {len(documents)} documents into {cls.get_collection_name()}"
+             )
+
+         except Exception as e:
+             logger.error(f"Error inserting documents: {e}")
+             logger.info(
+                 f"Collection '{cls.get_collection_name()}' does not exist. Trying to create the collection and reinsert the documents."
+             )
+
+             cls.create_collection()
+
+             try:
+                 cls._bulk_insert(documents)
+             except Exception as e:
+                 logger.error(f"Error inserting documents: {e}")
+                 logger.error(
+                     f"Failed to insert documents in '{cls.get_collection_name()}'."
+                 )
+
+                 return False
+
+         return True
+
+     @classmethod
+     def _bulk_insert(cls: Type[T], documents: list["VectorBaseDocument"]) -> None:
+         points = [doc.to_point() for doc in documents]
+
+         connection.upsert(collection_name=cls.get_collection_name(), points=points)
+
+     @classmethod
+     def bulk_find(
+         cls: Type[T], limit: int = 10, **kwargs
+     ) -> tuple[list[T], UUID | None]:
+         try:
+             documents, next_offset = cls._bulk_find(limit=limit, **kwargs)
+         except exceptions.UnexpectedResponse:
+             logger.error(
+                 f"Failed to search documents in '{cls.get_collection_name()}'."
+             )
+
+             documents, next_offset = [], None
+
+         return documents, next_offset
+
+     @classmethod
+     def _bulk_find(
+         cls: Type[T], limit: int = 10, **kwargs
+     ) -> tuple[list[T], UUID | None]:
+         collection_name = cls.get_collection_name()
+
+         offset = kwargs.pop("offset", None)
+         offset = str(offset) if offset else None
+
+         records, next_offset = connection.scroll(
+             collection_name=collection_name,
+             limit=limit,
+             with_payload=kwargs.pop("with_payload", True),
+             with_vectors=kwargs.pop("with_vectors", False),
+             offset=offset,
+             **kwargs,
+         )
+         documents = [cls.from_record(record) for record in records]
+         if next_offset is not None:
+             next_offset = UUID(next_offset, version=4)
+
+         return documents, next_offset
+
+     @classmethod
+     def search(cls: Type[T], query_vector: list, limit: int = 10, **kwargs) -> list[T]:
+         try:
+             documents = cls._search(query_vector=query_vector, limit=limit, **kwargs)
+         except exceptions.UnexpectedResponse:
+             logger.error(
+                 f"Failed to search documents in '{cls.get_collection_name()}'."
+             )
+
+             documents = []
+
+         return documents
+
+     @classmethod
+     def _search(cls: Type[T], query_vector: list, limit: int = 10, **kwargs) -> list[T]:
+         collection_name = cls.get_collection_name()
+         records = connection.search(
+             collection_name=collection_name,
+             query_vector=query_vector,
+             limit=limit,
+             with_payload=kwargs.pop("with_payload", True),
+             with_vectors=kwargs.pop("with_vectors", False),
+             **kwargs,
+         )
+         documents = [cls.from_record(record) for record in records]
+
+         return documents
+
+     @classmethod
+     def get_or_create_collection(cls: Type[T]) -> CollectionInfo:
+         collection_name = cls.get_collection_name()
+
+         try:
+             return connection.get_collection(collection_name=collection_name)
+         except exceptions.UnexpectedResponse:
+             use_vector_index = cls.get_use_vector_index()
+
+             collection_created = cls._create_collection(
+                 collection_name=collection_name, use_vector_index=use_vector_index
+             )
+             if collection_created is False:
+                 raise RuntimeError(
+                     f"Couldn't create collection {collection_name}"
+                 ) from None
+
+             return connection.get_collection(collection_name=collection_name)
+
+     @classmethod
+     def create_collection(cls: Type[T]) -> bool:
+         collection_name = cls.get_collection_name()
+         use_vector_index = cls.get_use_vector_index()
+         logger.info(
+             f"Creating collection {collection_name} with use_vector_index={use_vector_index}"
+         )
+         return cls._create_collection(
+             collection_name=collection_name, use_vector_index=use_vector_index
+         )
+
+     @classmethod
+     def _create_collection(
+         cls, collection_name: str, use_vector_index: bool = True
+     ) -> bool:
+         if use_vector_index is True:
+             vectors_config = VectorParams(size=EMBEDDING_SIZE, distance=Distance.COSINE)
+         else:
+             vectors_config = {}
+
+         return connection.create_collection(
+             collection_name=collection_name, vectors_config=vectors_config
+         )
+
+     @classmethod
+     def get_collection_name(cls: Type[T]) -> str:
+         if not hasattr(cls, "Config") or not hasattr(cls.Config, "name"):
+             raise Exception(
+                 f"The class {cls} should define a Config class with the 'name' property that reflects the collection's name."
+             )
+
+         return cls.Config.name
+
+     @classmethod
+     def get_use_vector_index(cls: Type[T]) -> bool:
+         if not hasattr(cls, "Config") or not hasattr(cls.Config, "use_vector_index"):
+             return True
+
+         return cls.Config.use_vector_index
+
+     @classmethod
+     def group_by_class(
+         cls: Type["VectorBaseDocument"], documents: list["VectorBaseDocument"]
+     ) -> Dict["VectorBaseDocument", list["VectorBaseDocument"]]:
+         return cls._group_by(documents, selector=lambda doc: doc.__class__)
+
+     @classmethod
+     def _group_by(
+         cls: Type[T], documents: list[T], selector: Callable[[T], Any]
+     ) -> Dict[Any, list[T]]:
+         grouped = {}
+         for doc in documents:
+             key = selector(doc)
+
+             if key not in grouped:
+                 grouped[key] = []
+             grouped[key].append(doc)
+
+         return grouped
+
+     @classmethod
+     def collection_name_to_class(
+         cls: Type["VectorBaseDocument"], collection_name: str
+     ) -> type["VectorBaseDocument"]:
+         for subclass in cls.__subclasses__():
+             try:
+                 if subclass.get_collection_name() == collection_name:
+                     return subclass
+             except Exception:
+                 pass
+
+             try:
+                 return subclass.collection_name_to_class(collection_name)
+             except ValueError:
+                 continue
+
+         raise ValueError(f"No subclass found for collection name: {collection_name}")
+
+     @classmethod
+     def _has_class_attribute(cls: Type[T], attribute_name: str) -> bool:
+         if attribute_name in cls.__annotations__:
+             return True
+
+         for base in cls.__bases__:
+             if hasattr(base, "_has_class_attribute") and base._has_class_attribute(
+                 attribute_name
+             ):
+                 return True
+
+         return False
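Subclasses opt into this ODM-style layer through their inner `Config`. A sketch of a hypothetical document type (`NoteDocument` and the `notes` collection are illustrative, not part of the repo); vectors must match the module-level `EMBEDDING_SIZE`:

```python
import uuid

from rag_demo.preprocessing.base.vectordb import EMBEDDING_SIZE, VectorBaseDocument


class NoteDocument(VectorBaseDocument):
    text: str
    embedding: list[float] | None = None

    class Config:
        name = "notes"  # Qdrant collection name
        use_vector_index = True


NoteDocument.create_collection()

doc = NoteDocument(id=uuid.uuid4(), text="hello", embedding=[0.1] * EMBEDDING_SIZE)
NoteDocument.bulk_insert([doc])

hits = NoteDocument.search(query_vector=[0.1] * EMBEDDING_SIZE, limit=1)
print(hits[0].text)
```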
rag_demo/preprocessing/chunking.py ADDED
@@ -0,0 +1,26 @@
+ from uuid import uuid4
+
+ from langchain.text_splitter import MarkdownTextSplitter
+ from rag_demo.preprocessing.base import Chunk
+ from rag_demo.preprocessing.base import Document
+
+
+ def chunk_text(
+     document: Document, chunk_size: int = 500, chunk_overlap: int = 50
+ ) -> list[Chunk]:
+     text_splitter = MarkdownTextSplitter(
+         chunk_size=chunk_size, chunk_overlap=chunk_overlap
+     )
+     chunks = text_splitter.split_text(document.text)
+     result = []
+     for chunk in chunks:
+         result.append(
+             Chunk(
+                 content=chunk,
+                 document_id=document.document_id,
+                 chunk_id=uuid4(),
+                 metadata=document.metadata,
+             )
+         )
+
+     return result
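A quick sketch of the splitter on an in-memory document:

```python
from uuid import uuid4

from rag_demo.preprocessing.base import Document
from rag_demo.preprocessing.chunking import chunk_text

doc = Document(
    text="# Title\n\n" + "Some markdown body text. " * 100,
    document_id=uuid4(),
    metadata={"filename": "example.pdf"},
)

# Each chunk keeps the parent document_id and metadata, plus its own chunk_id
chunks = chunk_text(doc, chunk_size=500, chunk_overlap=50)
print(len(chunks), chunks[0].content[:60])
```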
rag_demo/preprocessing/embed.py ADDED
@@ -0,0 +1,57 @@
+ from typing_extensions import Annotated
+ from typing import Generator
+ from .base import Chunk
+ from .base import EmbeddedChunk
+ from .chunking import chunk_text
+ from huggingface_hub import InferenceClient
+ import os
+ from dotenv import load_dotenv
+ from uuid import uuid4
+ from loguru import logger
+
+ load_dotenv()
+
+
+ def batch(list_: list, size: int) -> Generator[list, None, None]:
+     yield from (list_[i : i + size] for i in range(0, len(list_), size))
+
+
+ def embed_chunks(chunks: list[Chunk]) -> list[EmbeddedChunk]:
+     api = InferenceClient(
+         model="intfloat/multilingual-e5-large-instruct",
+         token=os.getenv("HF_API_TOKEN"),
+     )
+     logger.info(f"Embedding {len(chunks)} chunks")
+     embedded_chunks = []
+     for chunk in chunks:
+         try:
+             embedded_chunks.append(
+                 EmbeddedChunk(
+                     id=uuid4(),
+                     content=chunk.content,
+                     embedding=api.feature_extraction(chunk.content),
+                     document_id=chunk.document_id,
+                     chunk_id=chunk.chunk_id,
+                     metadata=chunk.metadata,
+                     similarity=None,
+                 )
+             )
+         except Exception as e:
+             logger.error(f"Error embedding chunk: {e}")
+     logger.info(f"{len(embedded_chunks)} chunks embedded successfully")
+
+     return embedded_chunks
+
+
+ def chunk_and_embed(
+     cleaned_documents: Annotated[list, "cleaned_documents"],
+ ) -> Annotated[list, "embedded_documents"]:
+     embedded_chunks = []
+     for document in cleaned_documents:
+         chunks = chunk_text(document)
+
+         for batched_chunks in batch(chunks, 10):
+             batched_embedded_chunks = embed_chunks(batched_chunks)
+             embedded_chunks.extend(batched_embedded_chunks)
+     logger.info(f"{len(embedded_chunks)} chunks embedded successfully")
+     return embedded_chunks
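`batch` simply caps how many chunks are sent to the Inference API per call:

```python
from rag_demo.preprocessing.embed import batch

print(list(batch([1, 2, 3, 4, 5, 6, 7], size=3)))
# [[1, 2, 3], [4, 5, 6], [7]]
```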
rag_demo/preprocessing/load_to_vectordb.py ADDED
@@ -0,0 +1,30 @@
+ from loguru import logger
+ from typing_extensions import Annotated
+ from typing import Generator
+
+ from .base import VectorBaseDocument
+
+
+ def batch(list_: list, size: int) -> Generator[list, None, None]:
+     yield from (list_[i : i + size] for i in range(0, len(list_), size))
+
+
+ def load_to_vector_db(
+     documents: Annotated[list, "documents"],
+ ) -> Annotated[bool, "successful"]:
+     logger.info(f"Loading {len(documents)} documents into the vector database.")
+
+     grouped_documents = VectorBaseDocument.group_by_class(documents)
+     for document_class, class_documents in grouped_documents.items():
+         logger.info(f"Loading documents into {document_class.get_collection_name()}")
+         for documents_batch in batch(class_documents, size=4):
+             try:
+                 document_class.bulk_insert(documents_batch)
+             except Exception as e:
+                 logger.error(
+                     f"Failed to insert documents into {document_class.get_collection_name()}: {e}"
+                 )
+
+                 return False
+
+     return True
rag_demo/preprocessing/pdf_conversion.py ADDED
@@ -0,0 +1,35 @@
+ import os
+ from uuid import uuid4
+
+ from llama_parse import LlamaParse
+ from llama_index.core import SimpleDirectoryReader
+ from loguru import logger
+ from dotenv import load_dotenv
+
+ from .base import Document
+
+ load_dotenv()
+
+
+ # set up parser (read the API key from the environment rather than hardcoding it)
+ parser = LlamaParse(
+     api_key=os.getenv("LLAMA_CLOUD_API_KEY"),
+     result_type="markdown",  # "markdown" and "text" are available
+ )
+
+
+ def convert_pdf_to_text(filepaths: list[str]) -> Document:
+     # use SimpleDirectoryReader to parse our file
+     file_extractor = {".pdf": parser}
+
+     documents = SimpleDirectoryReader(
+         input_files=filepaths, file_extractor=file_extractor
+     ).load_data()
+
+     logger.info(f"Converted {len(documents)} document(s)")
+
+     return Document(
+         document_id=uuid4(),
+         text=" ".join(document.text for document in documents),
+         metadata={"filename": filepaths[0].split("/")[-1]},
+     )
rag_demo/rag/__pycache__/prompt_templates.cpython-311.pyc ADDED
Binary file (2.75 kB).

rag_demo/rag/__pycache__/query_expansion.cpython-311.pyc ADDED
Binary file (2.4 kB).

rag_demo/rag/__pycache__/reranker.cpython-311.pyc ADDED
Binary file (1.96 kB).

rag_demo/rag/__pycache__/retriever.cpython-311.pyc ADDED
Binary file (8.21 kB).
 
rag_demo/rag/base/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .template_factory import PromptTemplateFactory
+
+ __all__ = ["PromptTemplateFactory"]
rag_demo/rag/base/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (283 Bytes).

rag_demo/rag/base/__pycache__/query.cpython-311.pyc ADDED
Binary file (2.08 kB).

rag_demo/rag/base/__pycache__/template_factory.cpython-311.pyc ADDED
Binary file (1.64 kB).
 
rag_demo/rag/base/base.py ADDED
@@ -0,0 +1,22 @@
+ from abc import ABC, abstractmethod
+ from typing import Any
+
+ from langchain.prompts import PromptTemplate
+ from pydantic import BaseModel
+
+ from rag_demo.rag.base.query import Query
+
+
+ class PromptTemplateFactory(ABC, BaseModel):
+     @abstractmethod
+     def create_template(self) -> PromptTemplate:
+         pass
+
+
+ class RAGStep(ABC):
+     def __init__(self, mock: bool = False) -> None:
+         self._mock = mock
+
+     @abstractmethod
+     def generate(self, query: Query, *args, **kwargs) -> Any:
+         pass
rag_demo/rag/base/query.py ADDED
@@ -0,0 +1,29 @@
+ from pydantic import Field
+
+ from rag_demo.preprocessing.base import VectorBaseDocument
+
+
+ class Query(VectorBaseDocument):
+     content: str
+     metadata: dict = Field(default_factory=dict)
+
+     class Config:
+         category = "query"
+
+     @classmethod
+     def from_str(cls, query: str) -> "Query":
+         return Query(content=query.strip("\n "))
+
+     def replace_content(self, new_content: str) -> "Query":
+         return Query(
+             id=self.id,
+             content=new_content,
+             metadata=self.metadata,
+         )
+
+
+ class EmbeddedQuery(Query):
+     embedding: list[float]
+
+     class Config:
+         category = "query"
rag_demo/rag/base/template_factory.py ADDED
@@ -0,0 +1,22 @@
+ from abc import ABC, abstractmethod
+ from typing import Any
+
+ from langchain.prompts import PromptTemplate
+ from pydantic import BaseModel
+
+ from .query import Query
+
+
+ class PromptTemplateFactory(ABC, BaseModel):
+     @abstractmethod
+     def create_template(self) -> PromptTemplate:
+         pass
+
+
+ class RAGStep(ABC):
+     def __init__(self, mock: bool = False) -> None:
+         self._mock = mock
+
+     @abstractmethod
+     def generate(self, query: Query, *args, **kwargs) -> Any:
+         pass
rag_demo/rag/prompt_templates.py ADDED
@@ -0,0 +1,38 @@
+ from langchain.prompts import PromptTemplate
+
+ from .base import PromptTemplateFactory
+
+
+ class QueryExpansionTemplate(PromptTemplateFactory):
+     prompt: str = """You are an AI language model assistant. Your task is to generate {expand_to_n}
+     different versions of the given user question to retrieve relevant documents from a vector
+     database. By generating multiple perspectives on the user question, your goal is to help
+     the user overcome some of the limitations of the distance-based similarity search.
+     Provide these alternative questions separated by '{separator}'.
+     Original question: {question}"""
+
+     @property
+     def separator(self) -> str:
+         return "#next-question#"
+
+     def create_template(self, expand_to_n: int) -> PromptTemplate:
+         return PromptTemplate(
+             template=self.prompt,
+             input_variables=["question"],
+             partial_variables={
+                 "separator": self.separator,
+                 "expand_to_n": expand_to_n,
+             },
+         )
+
+
+ class AnswerGenerationTemplate(PromptTemplateFactory):
+     prompt: str = """You are an AI language model assistant. Your task is to generate an answer to the given user question based on the provided context.
+     Context: {context}
+     Question: {question}
+
+     Give your answer in markdown format if needed, for example if a table is the best way to answer the question, or if titles and subheadings are needed.
+     Give only your answer, do not include any other text like 'Certainly! Here is the answer:' or 'The answer is:' or anything similar."""
+
+     def create_template(self, context: str, question: str) -> str:
+         return self.prompt.format(context=context, question=question)
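How the two templates are consumed downstream, as a short sketch:

```python
from rag_demo.rag.prompt_templates import (
    AnswerGenerationTemplate,
    QueryExpansionTemplate,
)

expansion = QueryExpansionTemplate()
prompt = expansion.create_template(expand_to_n=3)
# separator and expand_to_n are pre-bound as partial variables;
# only the question remains to fill in
print(prompt.format(question="What are the payment terms?"))

answer_template = AnswerGenerationTemplate()
print(answer_template.create_template(context="<retrieved chunks>", question="..."))
```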
rag_demo/rag/query_expansion.py ADDED
@@ -0,0 +1,37 @@
+ import os
+ from typing import Any
+
+ from huggingface_hub import InferenceClient
+
+ from rag_demo.rag.base.query import Query
+ from rag_demo.rag.base.template_factory import RAGStep
+ from rag_demo.rag.prompt_templates import QueryExpansionTemplate
+
+
+ class QueryExpansion(RAGStep):
+     def generate(self, query: Query, expand_to_n: int) -> Any:
+         api = InferenceClient(
+             model="Qwen/Qwen2.5-72B-Instruct",
+             token=os.getenv("HF_API_TOKEN"),
+         )
+         query_expansion_template = QueryExpansionTemplate()
+         prompt = query_expansion_template.create_template(expand_to_n - 1)
+         response = api.chat_completion(
+             [
+                 {
+                     "role": "user",
+                     # prompt.format fills separator and expand_to_n (the
+                     # expand_to_n - 1 bound above, since the original query
+                     # is kept as well) from the partial variables
+                     "content": prompt.format(question=query.content),
+                 }
+             ]
+         )
+         result = response.choices[0].message.content
+         queries_content = result.split(query_expansion_template.separator)
+         queries = [query]
+         queries += [
+             query.replace_content(stripped_content)
+             for content in queries_content
+             if (stripped_content := content.strip())
+         ]
+         return queries
rag_demo/rag/reranker.py ADDED
@@ -0,0 +1,24 @@
+ import os
+
+ from huggingface_hub import InferenceClient
+
+ from rag_demo.rag.base.query import Query
+ from rag_demo.rag.base.template_factory import RAGStep
+ from rag_demo.preprocessing.embed import EmbeddedChunk
+
+
+ class Reranker(RAGStep):
+     def generate(
+         self, query: Query, chunks: list[EmbeddedChunk], keep_top_k: int
+     ) -> list[EmbeddedChunk]:
+         api = InferenceClient(
+             model="intfloat/multilingual-e5-large-instruct",
+             token=os.getenv("HF_API_TOKEN"),
+         )
+         similarity = api.sentence_similarity(
+             query.content, [chunk.content for chunk in chunks]
+         )
+         for chunk, sim in zip(chunks, similarity):
+             chunk.similarity = sim
+
+         return sorted(chunks, key=lambda x: x.similarity, reverse=True)[:keep_top_k]
rag_demo/rag/retriever.py ADDED
@@ -0,0 +1,132 @@
+ import concurrent.futures
+ import os
+
+ from loguru import logger
+ from huggingface_hub import InferenceClient
+
+ from rag_demo.preprocessing.base import (
+     EmbeddedChunk,
+ )
+ from rag_demo.rag.base.query import EmbeddedQuery, Query
+
+ from .query_expansion import QueryExpansion
+ from .reranker import Reranker
+ from .prompt_templates import AnswerGenerationTemplate
+
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+
+ def flatten(nested_list: list) -> list:
+     """Flatten a list of lists into a single list."""
+
+     return [item for sublist in nested_list for item in sublist]
+
+
+ class RAGPipeline:
+     def __init__(self, mock: bool = False) -> None:
+         self._query_expander = QueryExpansion(mock=mock)
+         self._reranker = Reranker(mock=mock)
+
+     def search(
+         self,
+         query: str,
+         k: int = 3,
+         expand_to_n_queries: int = 3,
+     ) -> list:
+         query_model = Query.from_str(query)
+
+         n_generated_queries = self._query_expander.generate(
+             query_model, expand_to_n=expand_to_n_queries
+         )
+         logger.info(
+             f"Successfully generated {len(n_generated_queries)} search queries.",
+         )
+
+         with concurrent.futures.ThreadPoolExecutor() as executor:
+             search_tasks = [
+                 executor.submit(self._search, _query_model, k)
+                 for _query_model in n_generated_queries
+             ]
+
+             n_k_documents = [
+                 task.result() for task in concurrent.futures.as_completed(search_tasks)
+             ]
+             n_k_documents = flatten(n_k_documents)
+             n_k_documents = list(set(n_k_documents))
+
+         logger.info(f"{len(n_k_documents)} documents retrieved successfully")
+
+         if len(n_k_documents) > 0:
+             k_documents = self.rerank(query, chunks=n_k_documents, keep_top_k=k)
+         else:
+             k_documents = []
+
+         return k_documents
+
+     def _search(self, query: Query, k: int = 3) -> list[EmbeddedChunk]:
+         assert k >= 3, "k should be >= 3"
+
+         def _search_data(
+             data_category_odm: type[EmbeddedChunk], embedded_query: EmbeddedQuery
+         ) -> list[EmbeddedChunk]:
+             return data_category_odm.search(
+                 query_vector=embedded_query.embedding,
+                 limit=k,
+             )
+
+         api = InferenceClient(
+             model="intfloat/multilingual-e5-large-instruct",
+             token=os.getenv("HF_API_TOKEN"),
+         )
+         embedded_query: EmbeddedQuery = EmbeddedQuery(
+             embedding=api.feature_extraction(query.content),
+             id=query.id,
+             content=query.content,
+         )
+
+         retrieved_chunks = _search_data(EmbeddedChunk, embedded_query)
+         logger.info(f"{len(retrieved_chunks)} documents retrieved successfully")
+
+         return retrieved_chunks
+
+     def rerank(
+         self, query: str | Query, chunks: list[EmbeddedChunk], keep_top_k: int
+     ) -> list[EmbeddedChunk]:
+         if isinstance(query, str):
+             query = Query.from_str(query)
+
+         reranked_documents = self._reranker.generate(
+             query=query, chunks=chunks, keep_top_k=keep_top_k
+         )
+
+         logger.info(f"{len(reranked_documents)} documents reranked successfully.")
+
+         return reranked_documents
+
+     def generate_answer(self, query: str, reranked_chunks: list[EmbeddedChunk]) -> str:
+         context = ""
+         for chunk in reranked_chunks:
+             context += "\n Document: "
+             context += chunk.content
+         api = InferenceClient(
+             model="meta-llama/Llama-3.1-8B-Instruct",
+             token=os.getenv("HF_API_TOKEN"),
+         )
+         answer_generation_template = AnswerGenerationTemplate()
+         prompt = answer_generation_template.create_template(context, query)
+         logger.info(prompt)
+         response = api.chat_completion(
+             [{"role": "user", "content": prompt}],
+             max_tokens=8192,
+         )
+         return response.choices[0].message.content
+
+     def rag(self, query: str) -> tuple[str, list[str]]:
+         docs = self.search(query, k=10)
+         reranked_docs = self.rerank(query, docs, keep_top_k=10)
+         return (
+             self.generate_answer(query, reranked_docs),
+             [doc.metadata["filename"].split(".pdf")[0] for doc in reranked_docs],
+         )
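End to end, the pipeline is driven like this. A sketch, assuming HF_API_TOKEN is set and documents have already been ingested into the in-memory Qdrant collection:

```python
from rag_demo.rag.retriever import RAGPipeline

pipeline = RAGPipeline()

answer, sources = pipeline.rag("What are the payment terms?")
print(answer)   # markdown answer from meta-llama/Llama-3.1-8B-Instruct
print(sources)  # source PDF filenames, without the .pdf extension
```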