JPBianchi commited on
Commit
f85a680
·
1 Parent(s): 94263d8

multiRAG implementation

Browse files
app/engine/chunk_embed.py CHANGED
@@ -1,10 +1,8 @@
1
-
2
-
3
  import os
4
  import pandas as pd
5
  import torch
6
 
7
- from settings import parquet_file
8
 
9
  import tiktoken # tokenizer library for use with OpenAI LLMs
10
  from llama_index.legacy.text_splitter import SentenceSplitter
 
 
 
1
  import os
2
  import pandas as pd
3
  import torch
4
 
5
+ from app.settings import parquet_file
6
 
7
  import tiktoken # tokenizer library for use with OpenAI LLMs
8
  from llama_index.legacy.text_splitter import SentenceSplitter
app/engine/loaders/file.py CHANGED
@@ -3,13 +3,15 @@ import os
3
  # from langchain.document_loaders import PyPDFLoader # deprecated
4
  from langchain_community.document_loaders import PyPDFLoader
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 
6
  from llama_parse import LlamaParse
7
 
8
  from typing import Union, List, Dict
9
 
10
  from abc import ABC, abstractmethod
11
 
12
- class PDFExtractor(ABC):
13
 
14
  def __init__(self, file_or_list: Union[str, List[str]], num_workers: int = 1, verbose: bool = False):
15
  """ We can provide a list of files or a single file """
@@ -40,7 +42,7 @@ class PDFExtractor(ABC):
40
  """
41
  pass
42
 
43
- class _PyPDFLoader(PDFExtractor):
44
 
45
  def extract_text(self):
46
  output_dict = {}
@@ -58,7 +60,7 @@ class _PyPDFLoader(PDFExtractor):
58
  return
59
 
60
 
61
- class _LlamaParse(PDFExtractor):
62
 
63
  def extract_text(self):
64
  # https://github.com/run-llama/llama_parse
@@ -88,18 +90,59 @@ class _LlamaParse(PDFExtractor):
88
  raise NotImplementedError("Not implemented or LlamaParse does not support table extraction")
89
  return
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- def pdf_extractor(extractor_type: str, *args, **kwargs) -> PDFExtractor:
93
- """ Factory function to return the appropriate PDF extractor instance, properly initialized """
 
 
 
 
 
 
 
 
 
 
94
 
95
  if extractor_type == 'PyPDFLoader':
96
  return _PyPDFLoader(*args, **kwargs)
97
 
98
  elif extractor_type == 'LlamaParse':
99
  return _LlamaParse(*args, **kwargs)
 
 
 
 
100
  else:
101
  raise ValueError(f"Unsupported PDF extractor type: {extractor_type}")
102
 
103
 
104
 
105
-
 
 
3
  # from langchain.document_loaders import PyPDFLoader # deprecated
4
  from langchain_community.document_loaders import PyPDFLoader
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain_community.document_loaders.csv_loader import CSVLoader
7
+ # ^ if we want to add CSV support, it will transform every row into a k:v pair
8
  from llama_parse import LlamaParse
9
 
10
  from typing import Union, List, Dict
11
 
12
  from abc import ABC, abstractmethod
13
 
14
+ class Extractor(ABC):
15
 
16
  def __init__(self, file_or_list: Union[str, List[str]], num_workers: int = 1, verbose: bool = False):
17
  """ We can provide a list of files or a single file """
 
42
  """
43
  pass
44
 
45
+ class _PyPDFLoader(Extractor):
46
 
47
  def extract_text(self):
48
  output_dict = {}
 
60
  return
61
 
62
 
63
+ class _LlamaParse(Extractor):
64
 
65
  def extract_text(self):
66
  # https://github.com/run-llama/llama_parse
 
90
  raise NotImplementedError("Not implemented or LlamaParse does not support table extraction")
91
  return
92
 
93
+ class _TXTLoader(Extractor):
94
+
95
+ def extract_text(self):
96
+ output_dict = {}
97
+ for fpath in self.filelist:
98
+ fname = fpath.split('/')[-1]
99
+ output_dict[fname] = [open(fpath, 'r').read()]
100
+ # with pdfs, we use a list of strings, one for each page
101
+ # so we must return a list here, even if it's just one string with everything
102
+ return output_dict
103
+
104
+ def extract_images(self):
105
+ raise NotImplementedError("Not implemented or PyPDFLoader does not support image extraction")
106
+ return
107
+
108
+ def extract_tables(self):
109
+ raise NotImplementedError("Not implemented or PyPDFLoader does not support table extraction")
110
+ return
111
+
112
+ class _CSVLoader(Extractor):
113
+ # mock code for now, as a reminder of what we could do if time allows TODO
114
+ def extract_text(self):
115
+ output_dict = {}
116
+ for fpath in self.filelist:
117
+ fname = fpath.split('/')[-1]
118
+ output_dict[fname] = [CSVLoader(fpath).load()] # << untested!
119
 
120
+ return output_dict
121
+
122
+ def extract_images(self):
123
+ raise NotImplementedError("Not implemented or CSVLoader does not support image extraction")
124
+ return
125
+
126
+ def extract_tables(self):
127
+ raise NotImplementedError("Not implemented or CSVLoader does not support table extraction")
128
+ return
129
+
130
+ def extractor(extractor_type: str, *args, **kwargs) -> Extractor:
131
+ """ Function factory to return the appropriate PDF extractor instance, properly initialized """
132
 
133
  if extractor_type == 'PyPDFLoader':
134
  return _PyPDFLoader(*args, **kwargs)
135
 
136
  elif extractor_type == 'LlamaParse':
137
  return _LlamaParse(*args, **kwargs)
138
+
139
+ elif extractor_type == 'txt':
140
+ return _TXTLoader(*args, **kwargs)
141
+
142
  else:
143
  raise ValueError(f"Unsupported PDF extractor type: {extractor_type}")
144
 
145
 
146
 
147
+ #/usr/bin/env /Users/jpb2/Library/Caches/pypoetry/virtualenvs/reflex-Y1r5RCNB-py3.10/bin/python /Users/jpb2/.vscode/extensions/ms-python.debugpy-2024.6.0-darwin-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher 51572 -- -m reflex run --frontend-port 3000 --loglevel debug
148
+ #/usr/bin/env /Volumes/DATA/Dropbox/IMAC_BACKUP/WORK/PROJECTS/INNOVATION/venv/bin/python /Users/jpb2/.vscode/extensions/ms-python.debugpy-2024.6.0-darwin-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher 53961 -- -m reflex run --frontend-port 3001 --loglevel debug --env dev
app/engine/logger.py CHANGED
@@ -1,10 +1,16 @@
1
  import os, logging
 
 
2
 
3
- environment = os.getenv("ENVIRONMENT", "dev")
4
- if environment == "dev":
5
- logger = logging.getLogger("uvicorn")
6
- else:
7
- logger = lambda x: _
8
- # we should log also in production TODO
9
- # check how it works on HuggingFace, if possible
10
- # because we don't have access to the container's file system
 
 
 
 
 
1
  import os, logging
2
+ # import reflex as rx
3
+ logger = logging.getLogger("uvicorn").info
4
 
5
+
6
+ # logger = lambda x: rx.console_log(x)
7
+ # let's use reflex's logger, but doesn't show in the console??
8
+
9
+ # environment = os.getenv("ENVIRONMENT", "dev")
10
+ # if environment == "dev":
11
+ # logger = logging.getLogger("uvicorn").info
12
+ # else:
13
+ # logger = lambda x: print(x)
14
+ # # we should log also in production TODO
15
+ # # check how it works on HuggingFace, if possible
16
+ # # because we don't have access to the container's file system unless in pro mode
app/engine/post_process.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import reflex as rx
2
+ import json
3
+ import requests
4
+ from typing import Optional, List
5
+ from pydantic import BaseModel, Field
6
+ # from rerank import ReRanker
7
+
8
+ # https://hub.guardrailsai.com/validator/guardrails/toxic_language
9
+ from guardrails.hub import ToxicLanguage
10
+ from guardrails import Guard
11
+
12
+ # guardrails hub install hub://guardrails/detect_pii
13
+ from guardrails.hub import DetectPII
14
+
15
+ # https://hub.guardrailsai.com/validator/guardrails/qa_relevance_llm_eval
16
+ from guardrails.hub import QARelevanceLLMEval
17
+
18
+ import logging
19
+ logger = logging.getLogger("uvicorn").info
20
+
21
+ from .summary import summarize_it
22
+
23
+
24
+ def IsPii(answer: str) -> bool:
25
+ guard = Guard().use(DetectPII,
26
+ ["EMAIL_ADDRESS", "PHONE_NUMBER"],
27
+ "exception",
28
+ )
29
+ try:
30
+ guard.validate(answer)
31
+ return True
32
+
33
+ except Exception as e:
34
+ print(e)
35
+ return False
36
+
37
+ def IsToxic(query: str, threshold=0.5) -> bool:
38
+
39
+ # https://hub.guardrailsai.com/validator/guardrails/toxic_language
40
+ # Use the Guard with the validator
41
+ guard = Guard().use(
42
+ ToxicLanguage,
43
+ threshold=threshold, # high for highly toxic only
44
+ validation_method="sentence",
45
+ on_fail="exception"
46
+ )
47
+
48
+ try:
49
+ guard.validate(query)
50
+ return False
51
+
52
+ except Exception as e:
53
+ print(e) # will output the toxic question
54
+ return True
55
+
56
+ def IsRelevant(answer: str, query: str, model: str="gpt-3.5-turbo") -> bool:
57
+
58
+ guard = Guard().use(
59
+ QARelevanceLLMEval,
60
+ llm_callable=model,
61
+ on_fail="exception",
62
+ )
63
+
64
+ try:
65
+ guard.validate(
66
+ answer,
67
+ metadata={"original_prompt": query},
68
+ )
69
+ return True
70
+ except Exception as e:
71
+ print(e)
72
+ return False
73
+
74
+
app/engine/processing.py CHANGED
@@ -1,48 +1,146 @@
1
  import os, pickle
2
  from typing import List
3
- from engine.loaders.file import pdf_extractor
4
- from engine.chunk_embed import chunk_vectorize
5
- from settings import parquet_file
6
  from .logger import logger
7
  from .vectorstore import VectorStore
8
- # I allow relative imports inside the engine package
9
- # I could have created a module but things are still changing
 
 
 
10
 
11
- finrag_vectorstore = VectorStore(model_path='sentence-transformers/all-mpnet-base-v2')
12
-
13
 
14
  def empty_collection():
15
- """ Deletes the Finrag collection if it exists """
16
- status = finrag_vectorstore.empty_collection()
17
  return status
18
 
19
 
20
  def index_data():
21
 
22
  if not os.path.exists(parquet_file):
23
- logger.info(f"Parquet file {parquet_file} does not exists")
24
  return 'no data to index'
25
 
26
  # load the parquet file into the vectorstore
27
- finrag_vectorstore.index_data()
28
  os.remove(parquet_file)
29
  # delete the files so we can load several files and index them when we want
30
  # without having to keep track of those that have been indexed already
31
  # this is a simple solution for now, but we can do better
32
 
33
  return "Index creation successful"
34
-
35
 
36
- def process_pdf(filepath:str) -> dict:
 
37
 
38
- new_content = pdf_extractor('PyPDFLoader', filepath).extract_text()
39
- logger.info(f"Successfully extracted text from PDF")
40
 
41
  chunk_vectorize(new_content)
42
- logger.info(f"Successfully vectorized PDF content")
43
  return new_content
44
 
45
- def vector_search(question:str) -> List[str]:
 
46
 
47
- ans = finrag_vectorstore.hybrid_search(query=question, limit=3, alpha=0.8)
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  return ans
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os, pickle
2
  from typing import List
3
+ from .loaders.file import extractor
4
+ from .chunk_embed import chunk_vectorize
5
+ from ..settings import parquet_file
6
  from .logger import logger
7
  from .vectorstore import VectorStore
8
+ from .post_process import IsPii, IsToxic, IsRelevant
9
+ from .summary import summarize_it
10
+ from .post_process import IsPii, IsToxic, IsRelevant
11
+
12
+ multirag_vectorstore = VectorStore(model_path='sentence-transformers/all-mpnet-base-v2')
13
 
 
 
14
 
15
  def empty_collection():
16
+ """ Deletes the MultiRAG collection if it exists """
17
+ status = multirag_vectorstore.empty_collection()
18
  return status
19
 
20
 
21
  def index_data():
22
 
23
  if not os.path.exists(parquet_file):
24
+ logger(f"Parquet file {parquet_file} does not exists")
25
  return 'no data to index'
26
 
27
  # load the parquet file into the vectorstore
28
+ multirag_vectorstore.index_data()
29
  os.remove(parquet_file)
30
  # delete the files so we can load several files and index them when we want
31
  # without having to keep track of those that have been indexed already
32
  # this is a simple solution for now, but we can do better
33
 
34
  return "Index creation successful"
 
35
 
36
+
37
+ def process_pdf(filepath: str) -> dict:
38
 
39
+ new_content = extractor('PyPDFLoader', filepath).extract_text()
40
+ logger(f"Successfully extracted text from PDF")
41
 
42
  chunk_vectorize(new_content)
43
+ logger(f"Successfully vectorized PDF content of {filepath}")
44
  return new_content
45
 
46
+
47
+ def process_txt(filepath: str) -> dict:
48
 
49
+ new_content = extractor('txt', filepath).extract_text()
50
+ logger(f"Successfully extracted text from TXT")
51
+
52
+ chunk_vectorize(new_content)
53
+ logger(f"Successfully vectorized TXT content")
54
+ return new_content
55
+
56
+
57
+ def vector_search_raw(question: str) -> List[str]:
58
+ """ Just vector search """
59
+ print("WE are in vector_search_raw")
60
+ ans = multirag_vectorstore.hybrid_search(query=question,
61
+ limit=6,
62
+ alpha=0.8)
63
  return ans
64
+
65
+
66
+ def vector_search(question: str, relevance_thr=0.3) -> List[str]:
67
+ """ Search + pre/post processing """
68
+
69
+ ## PRE PROCESSING
70
+ if IsToxic(question):
71
+ ans = [f"\"{question}\" is toxic, try again"]
72
+ return ans
73
+
74
+ ans = multirag_vectorstore.hybrid_search(query=question,
75
+ limit=5,
76
+ alpha=0.8)
77
+
78
+ max_score = max([score for _, _, score in ans])
79
+ # if no answer has a score high enough, we consider the question irrelevant
80
+ # we could do better with reranking but here the question is trivial, y/n
81
+ # it's not like reranking 100 answers to pick the best 5 for RAGing
82
+ if max_score < relevance_thr:
83
+ return [f"{question} is IRRELEVANT with max score: {max_score:.2f}, try again"]
84
+ else:
85
+ answers = [f"{question} is deemed RELEVANT with max score: {max_score:.2f}"]
86
+
87
+ # let's first quickly print the answers, without summary
88
+ for i, (fname, ans, score) in enumerate(ans, 1):
89
+
90
+ if score < relevance_thr:
91
+ continue
92
+
93
+ if IsPii(ans):
94
+ ans = " Pii detected -" + ans
95
+
96
+ # removed, not accurate
97
+ if IsRelevant(ans, question):
98
+ relevant = 'RELEVANT'
99
+ else:
100
+ # irrelevant answer
101
+ relevant = 'IRRELEVANT'
102
+
103
+ summary = summarize_it(question, [ans])
104
+ ans = f"{ans}\n SUMMARY: {summary}"
105
+
106
+ answers.append(f"{i}: from {fname} - score:{score:.2f} - {relevant} answer - {ans}")
107
+
108
+ # msg = f"Answers to '{self.question}' with summaries"
109
+ # self.chats[self.current_chat] = [qa1]
110
+
111
+ # for i, (fname, ans, score) in enumerate(self.answer['answer'], 1):
112
+
113
+ # if score < relevance_thr:
114
+ # continue
115
+
116
+ # msg = ""
117
+ # summary = summarize_it(self.question, [ans])
118
+
119
+ # # if IsPii(ans):
120
+ # # qa.answer += " Pii detected -"
121
+
122
+ # # removed, not accurate
123
+ # # if IsRelevant(ans, self.question):
124
+ # # relevant = 'RELEVANT'
125
+ # # else:
126
+ # # # irrelevant answer
127
+ # # relevant = 'IRRELEVANT'
128
+ # # qa.answer += f" {relevant} ANSWER - {ans} \n SUMMARY: {summary}"
129
+
130
+ # qa = QA(question=msg,
131
+ # answer=f"{i}: from {fname} - score:{score:.2f} - {ans} - SUMMARY: {summary}"
132
+ # )
133
+
134
+ # # paths are from /assets, so data is assets/data
135
+ # search = ans[:30].replace(" ", "%20") # let's search only first 30 chars
136
+ # qa.link = f'data/{fname}#:~:text={search}'
137
+ # qa.msg = " - Verify in the document"
138
+ # logger(f"Summary: {summary}")
139
+
140
+ # # it's slower now because of the summaries
141
+ # self.chats[self.current_chat].append(qa)
142
+ # yield
143
+
144
+ # msg = ""
145
+
146
+ return answers
app/engine/summary.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import List
3
+
4
+ from app.rag.llm import LLM
5
+ #the LLM Class uses the OPENAI_API_KEY env var as the default api_key
6
+
7
+
8
+ def summarize_it(question: str,
9
+ search_results: List[str],
10
+ model: str = 'gpt-3.5-turbo-0125',
11
+ ) -> str:
12
+
13
+ # TODO turn this into a class if time allows
14
+ llm = LLM(model)
15
+
16
+ system_message = """
17
+ You are able to quickly understand a few paragraphs, or quips even, generated by vector search system
18
+ and generate a one-line summary.
19
+ """
20
+
21
+ searches = "\n".join([f"Search result {i}: {v}" for i,v in enumerate(search_results,1)])
22
+
23
+ user_prompt = f"""
24
+ Use the below context enclosed in triple back ticks to answer the question. \n
25
+ The context is given by a vector search into a vector database made from the company's documents,
26
+ so you can assume the context is accurate. \n
27
+ ```
28
+ Context:
29
+ ```
30
+ {searches}
31
+ ```
32
+ Question:\n
33
+ {question}\n
34
+ ------------------------
35
+ 1. If the context is not relevant to the question, simply say 'Irrelevant content' and nothing else.
36
+ Pay great attention to making sure your answer is relevant to the question and the context.
37
+ (for instance, never answer a question about a topic that is not explicitely mentioned in the question)
38
+ 2. Using any external knowledge or resources to answer the question is forbidden.
39
+ 3. Generate a ONE-LINE ONE-LINE summary within the limits of the context and the question.
40
+ 4. Avoid mentioning 'search results' in the answer.
41
+ Instead, incorporate the information from the search results into the answer.
42
+ 5. Create a clean answer, without backticks, or starting with a new line for instance.
43
+ ------------------------
44
+ Answer:\n
45
+ """.format(searches=searches, question=question)
46
+
47
+
48
+ response = llm.chat_completion(system_message=system_message,
49
+ user_message=user_prompt,
50
+ temperature=0.01, # let's not allow the model to be creative
51
+ stream=False,
52
+ raw_response=False)
53
+ return response
app/engine/vectorstore.py CHANGED
@@ -1,19 +1,109 @@
1
  import os, logging
 
 
2
  from typing import List, Any
3
  import pandas as pd
4
  from weaviate.classes.config import Property, DataType
5
 
6
  from .weaviate_interface_v4 import WeaviateWCS, WeaviateIndexer
7
- from .logger import logger
8
 
9
- from settings import parquet_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  class VectorStore:
12
- def __init__(self, model_path:str = 'sentence-transformers/all-mpnet-base-v2'):
13
  # we can create several instances to test various models, especially if we finetune one
14
 
15
- self.finrag_properties = [
16
- Property(name='filename',
17
  data_type=DataType.TEXT,
18
  description='Name of the file',
19
  index_filterable=True,
@@ -30,45 +120,54 @@ class VectorStore:
30
  index_searchable=True),
31
  ]
32
 
33
- self.class_name = "FinRag_all-mpnet-base-v2"
34
 
35
  self.class_config = {'classes': [
36
 
37
  {"class": self.class_name,
38
 
39
- "description": "Financial reports",
40
 
41
  "vectorIndexType": "hnsw",
42
 
43
- # Vector index specific settings for HSNW
44
  "vectorIndexConfig": {
45
 
46
  "ef": 64, # higher is better quality vs slower search
47
  "efConstruction": 128, # higher = better index but slower build
48
  "maxConnections": 32, # max conn per layer - higher = more memory
49
  },
50
-
51
  "vectorizer": "none",
52
-
53
- "properties": self.finrag_properties }
54
  ]
55
  }
56
 
57
  self.model_path = model_path
58
-
59
  try:
60
  self.api_key = os.environ.get('FINRAG_WEAVIATE_API_KEY')
61
- self.url = os.environ.get('FINRAG_WEAVIATE_ENDPOINT')
62
- self.client = WeaviateWCS(endpoint=self.url,
63
- api_key=self.api_key,
64
- model_name_or_path=self.model_path)
65
-
 
 
 
 
 
 
66
  except Exception as e:
67
  # raise Exception(f"Could not create Weaviate client: {e}")
68
- print(f"Could not create Weaviate client: {e}")
69
-
70
- assert self.client._client.is_live(), "Weaviate is not live"
71
- assert self.client._client.is_ready(), "Weaviate is not ready"
 
 
 
72
  # careful with accessing '_client' since the weaviate helper usually closes the connection every time
73
 
74
  self.indexer = None
@@ -80,19 +179,21 @@ class VectorStore:
80
 
81
  return self.client.show_all_collections()
82
 
83
- def create_collection(self, collection_name: str='Finrag', description: str='Financial reports'):
 
 
84
 
85
  self.collection_name = collection_name
86
  if collection_name not in self.collections:
87
  self.client.create_collection(collection_name=collection_name,
88
- properties=self.finrag_properties,
89
  description=description)
90
- self.collection_name = collection_name
91
  else:
92
- logging.warning(f"Collection {collection_name} already exists")
93
 
94
 
95
- def empty_collection(self, collection_name: str='Finrag') -> bool:
96
 
97
  # not in the library yet, so I simply delete and recreate it
98
  if collection_name in self.collections:
@@ -100,11 +201,11 @@ class VectorStore:
100
  self.create_collection()
101
  return True
102
  else:
103
- logging.warning(f"Collection {collection_name} doesn't exist")
104
  return False
105
 
106
 
107
- def index_data(self, data: List[dict]= None, collection_name: str='Finrag'):
108
 
109
  if self.indexer is None:
110
  self.indexer = WeaviateIndexer(self.client)
@@ -127,25 +228,25 @@ class VectorStore:
127
  def keyword_search(self,
128
  query: str,
129
  limit: int=5,
130
- return_properties: List[str]=['filename', 'content'],
131
  alpha=None # dummy parameter to match the hybrid_search signature
132
  ) -> List[str]:
133
  response = self.client.keyword_search(
134
  request=query,
135
  collection_name=self.collection_name,
136
- query_properties=['content'],
137
  limit=limit,
138
  filter=None,
139
  return_properties=return_properties,
140
  return_raw=False)
141
 
142
- return [res['content'] for res in response]
143
 
144
 
145
  def vector_search(self,
146
  query: str,
147
  limit: int=5,
148
- return_properties: List[str]=['filename', 'content'],
149
  alpha=None # dummy parameter to match the hybrid_search signature
150
  ) -> List[str]:
151
 
@@ -157,24 +258,24 @@ class VectorStore:
157
  return_properties=return_properties,
158
  return_raw=False)
159
 
160
- return [res['content'] for res in response]
161
 
162
 
163
  def hybrid_search(self,
164
  query: str,
165
- limit: int=5,
166
  alpha=0.5, # higher = more vector search
167
- return_properties: List[str]=['filename', 'content']
168
  ) -> List[str]:
169
-
170
  response = self.client.hybrid_search(
171
  request=query,
172
  collection_name=self.collection_name,
173
- query_properties=['content'],
174
  alpha=alpha,
175
  limit=limit,
176
  filter=None,
177
  return_properties=return_properties,
178
  return_raw=False)
179
 
180
- return [res['content'] for res in response]
 
1
  import os, logging
2
+ from app.engine.logger import logger
3
+
4
  from typing import List, Any
5
  import pandas as pd
6
  from weaviate.classes.config import Property, DataType
7
 
8
  from .weaviate_interface_v4 import WeaviateWCS, WeaviateIndexer
 
9
 
10
+ from app.settings import parquet_file
11
+ from weaviate.classes.query import Filter
12
+ from torch import cuda
13
+
14
+ if os.path.exists('.we_are_local'):
15
+ COLLECTION = 'MultiRAG_local'
16
+ else:
17
+ COLLECTION = 'MultiRAG'
18
+
19
+ class dummyWeaviate:
20
+ """ Created to pass on HF since I had again the client creation issue
21
+ Temporary solution
22
+ """
23
+ def __init__(self,
24
+ endpoint: str=None,
25
+ api_key: str=None,
26
+ model_name_or_path: str='sentence-transformers/all-MiniLM-L6-v2',
27
+ embedded: bool=False,
28
+ openai_api_key: str=None,
29
+ skip_init_checks: bool=False,
30
+ **kwargs
31
+ ):
32
+ return
33
+
34
+ def _connect(self) -> None:
35
+ return
36
+
37
+ def _client(self):
38
+ return
39
+
40
+ def create_collection(self,
41
+ collection_name: str,
42
+ properties: list[Property],
43
+ description: str=None,
44
+ **kwargs
45
+ ) -> None:
46
+ return
47
+
48
+ def show_all_collections(self,
49
+ detailed: bool=False,
50
+ max_details: bool=False
51
+ ) -> list[str] | dict:
52
+ return ['abc', 'def']
53
+
54
+ def show_collection_config(self, collection_name: str):
55
+ return
56
+
57
+ def show_collection_properties(self, collection_name: str):
58
+ return
59
+
60
+ def delete_collection(self, collection_name: str):
61
+ return
62
+
63
+ def get_doc_count(self, collection_name: str):
64
+ return
65
+
66
+ def keyword_search(self,
67
+ request: str,
68
+ collection_name: str,
69
+ query_properties: list[str]=['content'],
70
+ limit: int=10,
71
+ filter: Filter=None,
72
+ return_properties: list[str]=None,
73
+ return_raw: bool=False
74
+ ):
75
+ return
76
+
77
+ def vector_search(self,
78
+ request: str,
79
+ collection_name: str,
80
+ limit: int=10,
81
+ return_properties: list[str]=None,
82
+ filter: Filter=None,
83
+ return_raw: bool=False,
84
+ device: str='cuda:0' if cuda.is_available() else 'cpu'
85
+ ):
86
+ return
87
+
88
+ def hybrid_search(self,
89
+ request: str,
90
+ collection_name: str,
91
+ query_properties: list[str]=['content'],
92
+ alpha: float=0.5,
93
+ limit: int=10,
94
+ filter: Filter=None,
95
+ return_properties: list[str]=None,
96
+ return_raw: bool=False,
97
+ device: str='cuda:0' if cuda.is_available() else 'cpu'
98
+ ):
99
+ return
100
 
101
  class VectorStore:
102
+ def __init__(self, model_path: str = 'sentence-transformers/all-mpnet-base-v2'):
103
  # we can create several instances to test various models, especially if we finetune one
104
 
105
+ self.MultiRAG_properties = [
106
+ Property(name='file',
107
  data_type=DataType.TEXT,
108
  description='Name of the file',
109
  index_filterable=True,
 
120
  index_searchable=True),
121
  ]
122
 
123
+ self.class_name = "MultiRAG_all-mpnet-base-v2"
124
 
125
  self.class_config = {'classes': [
126
 
127
  {"class": self.class_name,
128
 
129
+ "description": "multiple types of docs",
130
 
131
  "vectorIndexType": "hnsw",
132
 
133
+ # Vector index specific app.settings for HSNW
134
  "vectorIndexConfig": {
135
 
136
  "ef": 64, # higher is better quality vs slower search
137
  "efConstruction": 128, # higher = better index but slower build
138
  "maxConnections": 32, # max conn per layer - higher = more memory
139
  },
140
+
141
  "vectorizer": "none",
142
+
143
+ "properties": self.MultiRAG_properties}
144
  ]
145
  }
146
 
147
  self.model_path = model_path
148
+
149
  try:
150
  self.api_key = os.environ.get('FINRAG_WEAVIATE_API_KEY')
151
+ logger(f"API key: {self.api_key[:5]}")
152
+ self.url = os.environ.get('FINRAG_WEAVIATE_ENDPOINT')
153
+ logger(f"URL: {self.url[8:15]}")
154
+ self.client = WeaviateWCS(
155
+ endpoint=self.url,
156
+ api_key=self.api_key,
157
+ model_name_or_path=self.model_path,
158
+ )
159
+ assert self.client._client.is_live(), "Weaviate is not live"
160
+ assert self.client._client.is_ready(), "Weaviate is not ready"
161
+ logger(f"Weaviate client created")
162
  except Exception as e:
163
  # raise Exception(f"Could not create Weaviate client: {e}")
164
+ self.client = dummyWeaviate() # used when issue with HF client creation, to continue on HF
165
+ logger(f"Could not create Weaviate client: {e}")
166
+
167
+ # if we fail these tests 'VectorStore' object has no attribute 'client'
168
+ # it's prob not the env var but the model missing
169
+ # assert self.client._client.is_live(), "Weaviate is not live"
170
+ # assert self.client._client.is_ready(), "Weaviate is not ready"
171
  # careful with accessing '_client' since the weaviate helper usually closes the connection every time
172
 
173
  self.indexer = None
 
179
 
180
  return self.client.show_all_collections()
181
 
182
+ def create_collection(self,
183
+ collection_name: str=COLLECTION,
184
+ description: str='Documents'):
185
 
186
  self.collection_name = collection_name
187
  if collection_name not in self.collections:
188
  self.client.create_collection(collection_name=collection_name,
189
+ properties=self.MultiRAG_properties,
190
  description=description)
191
+ # self.collection_name = collection_name
192
  else:
193
+ logger(f"Collection {collection_name} already exists")
194
 
195
 
196
+ def empty_collection(self, collection_name: str=COLLECTION) -> bool:
197
 
198
  # not in the library yet, so I simply delete and recreate it
199
  if collection_name in self.collections:
 
201
  self.create_collection()
202
  return True
203
  else:
204
+ logger(f"Collection {collection_name} doesn't exist")
205
  return False
206
 
207
 
208
+ def index_data(self, data: List[dict]= None, collection_name: str=COLLECTION):
209
 
210
  if self.indexer is None:
211
  self.indexer = WeaviateIndexer(self.client)
 
228
  def keyword_search(self,
229
  query: str,
230
  limit: int=5,
231
+ return_properties: List[str]=['file', 'content'],
232
  alpha=None # dummy parameter to match the hybrid_search signature
233
  ) -> List[str]:
234
  response = self.client.keyword_search(
235
  request=query,
236
  collection_name=self.collection_name,
237
+ query_properties=['file', 'content'],
238
  limit=limit,
239
  filter=None,
240
  return_properties=return_properties,
241
  return_raw=False)
242
 
243
+ return [(res['file'], res['content'], res['score']) for res in response]
244
 
245
 
246
  def vector_search(self,
247
  query: str,
248
  limit: int=5,
249
+ return_properties: List[str]=['file', 'content'],
250
  alpha=None # dummy parameter to match the hybrid_search signature
251
  ) -> List[str]:
252
 
 
258
  return_properties=return_properties,
259
  return_raw=False)
260
 
261
+ return [(res['file'], res['content'], res['score']) for res in response]
262
 
263
 
264
  def hybrid_search(self,
265
  query: str,
266
+ limit: int=10,
267
  alpha=0.5, # higher = more vector search
268
+ return_properties: List[str]=['file', 'content']
269
  ) -> List[str]:
270
+ print("We are in hybrid_search")
271
  response = self.client.hybrid_search(
272
  request=query,
273
  collection_name=self.collection_name,
274
+ query_properties=['file', 'content'],
275
  alpha=alpha,
276
  limit=limit,
277
  filter=None,
278
  return_properties=return_properties,
279
  return_raw=False)
280
 
281
+ return [(res['file'], res['content'], res['score']) for res in response]
app/engine/weaviate_interface_v4.py CHANGED
@@ -343,9 +343,12 @@ class WeaviateWCS:
343
  If True, returns raw response from Weaviate.
344
  '''
345
  self._connect()
 
346
  return_properties = return_properties if return_properties else self.return_properties
347
  query_vector = self._create_query_vector(request, device=device)
 
348
  collection = self._client.collections.get(collection_name)
 
349
  response = collection.query.hybrid(query=request,
350
  query_properties=query_properties,
351
  filters=filter,
@@ -354,6 +357,7 @@ class WeaviateWCS:
354
  limit=limit,
355
  return_metadata=MetadataQuery(score=True, distance=True),
356
  return_properties=return_properties)
 
357
  if return_raw:
358
  return response
359
  else:
 
343
  If True, returns raw response from Weaviate.
344
  '''
345
  self._connect()
346
+ print("We are connected to Weaviate")
347
  return_properties = return_properties if return_properties else self.return_properties
348
  query_vector = self._create_query_vector(request, device=device)
349
+ print("After query vector")
350
  collection = self._client.collections.get(collection_name)
351
+ print("Just before query")
352
  response = collection.query.hybrid(query=request,
353
  query_properties=query_properties,
354
  filters=filter,
 
357
  limit=limit,
358
  return_metadata=MetadataQuery(score=True, distance=True),
359
  return_properties=return_properties)
360
+ print("After Weaviate response")
361
  if return_raw:
362
  return response
363
  else:
app/main.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
 
2
  import os, random, logging, pickle, shutil
3
  from dotenv import load_dotenv, find_dotenv
@@ -8,65 +10,83 @@ from fastapi import FastAPI, HTTPException, File, UploadFile, status
8
  from fastapi.responses import HTMLResponse
9
  from fastapi.middleware.cors import CORSMiddleware
10
 
11
- from engine.processing import process_pdf, index_data, empty_collection, vector_search
12
- from rag.rag import rag_it
 
 
 
13
 
14
- from engine.logger import logger
 
 
 
 
 
 
 
 
15
 
16
- from settings import datadir
17
 
18
- os.makedirs(datadir, exist_ok=True)
 
 
19
 
20
  app = FastAPI()
21
 
22
  environment = os.getenv("ENVIRONMENT", "dev") # created by dockerfile
23
 
24
- if environment == "dev":
25
- logger.warning("Running in development mode - allowing CORS for all origins")
26
- app.add_middleware(
27
- CORSMiddleware,
28
- allow_origins=["*"],
29
- allow_credentials=True,
30
- allow_methods=["*"],
31
- allow_headers=["*"],
32
- )
33
-
34
- try:
35
- load_dotenv(find_dotenv('env'))
36
-
37
- except Exception as e:
38
- pass
39
 
40
 
 
41
  @app.get("/", response_class=HTMLResponse)
42
  def read_root():
43
- logger.info("Title displayed on home page")
44
  return """
45
  <html>
46
  <body>
47
- <h1>Welcome to FinExpert, a RAG system designed by JP Bianchi!</h1>
48
  </body>
49
  </html>
50
  """
51
 
52
-
53
  @app.get("/ping/")
54
  def ping():
55
  """ Testing """
56
- logger.info("Someone is pinging the server")
57
  return {"answer": str(int(random.random() * 100))}
58
 
59
 
60
  @app.delete("/erase_data/")
61
  def erase_data():
62
- """ Erase all files in the data directory, but not the vector store """
 
 
 
 
63
  if len(os.listdir(datadir)) == 0:
64
- logger.info("No data to erase")
65
  return {"message": "No data to erase"}
66
 
67
- shutil.rmtree(datadir, ignore_errors=True)
68
- os.mkdir(datadir)
69
- logger.warning("All data has been erased")
 
 
 
 
 
70
  return {"message": "All data has been erased"}
71
 
72
 
@@ -75,15 +95,17 @@ def delete_vectors():
75
  """ Empty the collection in the vector store """
76
  try:
77
  status = empty_collection()
78
- return {f"""message": "Collection{'' if status else ' NOT'} erased!"""}
79
  except Exception as e:
80
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
81
 
 
82
  @app.get("/list_files/")
83
  def list_files():
84
  """ List all files in the data directory """
 
85
  files = os.listdir(datadir)
86
- logger.info(f"Files in data directory: {files}")
87
  return {"files": files}
88
 
89
 
@@ -93,18 +115,18 @@ async def upload_file(file: UploadFile = File(...)):
93
  """ Uploads a file in data directory, for later indexing """
94
  try:
95
  filepath = os.path.join(datadir, file.filename)
96
- logger.info(f"Fiename detected: {file.filename}")
97
  if os.path.exists(filepath):
98
- logger.warning(f"File {file.filename} already exists: no processing done")
99
  return {"message": f"File {file.filename} already exists: no processing done"}
100
 
101
  else:
102
- logger.info(f"Receiving file: {file.filename}")
103
  contents = await file.read()
104
- logger.info(f"File reception complete!")
105
 
106
  except Exception as e:
107
- logger.error(f"Error during file upload: {str(e)}")
108
  return {"message": f"Error during file upload: {str(e)}"}
109
 
110
  if file.filename.endswith('.pdf'):
@@ -112,9 +134,14 @@ async def upload_file(file: UploadFile = File(...)):
112
  # let's save the file in /data even if it's temp storage on HF
113
  with open(filepath, 'wb') as f:
114
  f.write(contents)
 
 
 
 
 
115
 
116
  try:
117
- logger.info(f"Starting to process {file.filename}")
118
  new_content = process_pdf(filepath)
119
  success = {"message": f"Successfully uploaded {file.filename}"}
120
  success.update(new_content)
@@ -122,15 +149,35 @@ async def upload_file(file: UploadFile = File(...)):
122
 
123
  except Exception as e:
124
  return {"message": f"Failed to extract text from PDF: {str(e)}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  else:
126
- return {"message": "Only PDF files are accepted"}
127
 
128
 
129
  @app.post("/create_index/")
130
  async def create_index():
131
  """ Create an index for the uploaded files """
132
 
133
- logger.info("Creating index for uploaded files")
134
  try:
135
  msg = index_data()
136
  return {"message": msg}
@@ -143,29 +190,30 @@ class Question(BaseModel):
143
 
144
  @app.post("/ask/")
145
  async def hybrid_search(question: Question):
146
- logger.info(f"Processing question: {question.question}")
147
  try:
148
  search_results = vector_search(question.question)
149
- logger.info(f"Answer: {search_results}")
150
  return {"answer": search_results}
151
  except Exception as e:
152
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
153
 
154
-
155
  @app.post("/ragit/")
156
  async def ragit(question: Question):
157
- logger.info(f"Processing question: {question.question}")
158
  try:
159
- search_results = vector_search(question.question)
160
- logger.info(f"Search results generated: {search_results}")
161
 
162
  answer = rag_it(question.question, search_results)
163
 
164
- logger.info(f"Answer: {answer}")
165
  return {"answer": answer}
166
  except Exception as e:
167
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
168
 
 
169
  if __name__ == '__main__':
170
  import uvicorn
171
  from os import getenv
@@ -175,16 +223,16 @@ if __name__ == '__main__':
175
  uvicorn.run("main:app", host="0.0.0.0", port=port, reload=reload)
176
 
177
 
178
-
179
  # Examples:
180
- # curl -X POST "http://localhost:80/upload" -F "file=@test.pdf"
181
- # curl -X DELETE "http://localhost:80/erase_data/"
182
- # curl -X GET "http://localhost:80/list_files/"
183
 
184
- # hf space is at https://jpbianchi-finrag.hf.space/
185
- # code given by https://jpbianchi-finrag.hf.space/docs
186
  # Space must be public
187
- # curl -X POST "https://jpbianchi-finrag.hf.space/upload/" -F "file=@test.pdf"
188
 
189
  # curl -X POST http://localhost:80/ask/ -H "Content-Type: application/json" -d '{"question": "what is Amazon loss"}'
190
  # curl -X POST http://localhost:80/ragit/ -H "Content-Type: application/json" -d '{"question": "Does ATT have postpaid phone customers?"}'
 
 
1
+ # this is the original main.py file, but without the call to fastapi
2
+ # since it is done by reflex's own fast api server
3
 
4
  import os, random, logging, pickle, shutil
5
  from dotenv import load_dotenv, find_dotenv
 
10
  from fastapi.responses import HTMLResponse
11
  from fastapi.middleware.cors import CORSMiddleware
12
 
13
+ try:
14
+ load_dotenv(find_dotenv('env'))
15
+
16
+ except Exception as e:
17
+ pass
18
 
19
+ from app.engine.processing import ( # << creates the collection already
20
+ process_pdf,
21
+ process_txt,
22
+ index_data,
23
+ empty_collection,
24
+ vector_search,
25
+ vector_search_raw
26
+ )
27
+ from app.rag.rag import rag_it
28
 
29
+ from app.engine.logger import logger
30
 
31
+ from app.settings import datadir, datadir2
32
+
33
+ EXTENSIONS = ["pdf", "txt"]
34
 
35
  app = FastAPI()
36
 
37
  environment = os.getenv("ENVIRONMENT", "dev") # created by dockerfile
38
 
39
+ # replaced by cors_allowed_origins=['*'] in rxconfig.py when using Reflex endpoint
40
+ # if environment == "dev":
41
+ # logger("Running in development mode - allowing CORS for all origins")
42
+ # app.add_middleware(
43
+ # CORSMiddleware,
44
+ # allow_origins=["*"],
45
+ # allow_credentials=True,
46
+ # allow_methods=["*"],
47
+ # allow_headers=["*"],
48
+ # )
 
 
 
 
 
49
 
50
 
51
+ # not used when using Reflex endpoint
52
  @app.get("/", response_class=HTMLResponse)
53
  def read_root():
54
+ logger("Title displayed on home page")
55
  return """
56
  <html>
57
  <body>
58
+ <h1>Welcome to MultiRAG, a RAG system designed by JP Bianchi!</h1>
59
  </body>
60
  </html>
61
  """
62
 
63
+ # already provided by Reflex
64
  @app.get("/ping/")
65
  def ping():
66
  """ Testing """
67
+ logger("Someone is pinging the server")
68
  return {"answer": str(int(random.random() * 100))}
69
 
70
 
71
  @app.delete("/erase_data/")
72
  def erase_data():
73
+ """ Erase all files in the data directory at the first level only,
74
+ (in case we would like to use it for something else)
75
+ but not the vector store or the parquet file.
76
+ We can do it since the embeddings are in the parquet file already.
77
+ """
78
  if len(os.listdir(datadir)) == 0:
79
+ logger("No data to erase")
80
  return {"message": "No data to erase"}
81
 
82
+ # if we try to rmtree datadir, it looks like /data can't be deleted on HF
83
+ for f in os.listdir(datadir):
84
+ if f == '.DS_Store' or f.split('.')[-1].lower() in EXTENSIONS:
85
+ print(f"Removing {f}")
86
+ os.remove(os.path.join(datadir, f))
87
+ # we don't remove the parquet file, create_index does that
88
+
89
+ logger("All data has been erased")
90
  return {"message": "All data has been erased"}
91
 
92
 
 
95
  """ Empty the collection in the vector store """
96
  try:
97
  status = empty_collection()
98
+ return {"message": f"Collection{'' if status else ' NOT'} erased!"}
99
  except Exception as e:
100
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
101
 
102
+
103
  @app.get("/list_files/")
104
  def list_files():
105
  """ List all files in the data directory """
106
+ print("Listing files")
107
  files = os.listdir(datadir)
108
+ logger(f"Files in data directory: {files}")
109
  return {"files": files}
110
 
111
 
 
115
  """ Uploads a file in data directory, for later indexing """
116
  try:
117
  filepath = os.path.join(datadir, file.filename)
118
+ logger(f"Fiename detected: {file.filename}")
119
  if os.path.exists(filepath):
120
+ logger(f"File {file.filename} already exists: no processing done")
121
  return {"message": f"File {file.filename} already exists: no processing done"}
122
 
123
  else:
124
+ logger(f"Receiving file: {file.filename}")
125
  contents = await file.read()
126
+ logger(f"File reception complete!")
127
 
128
  except Exception as e:
129
+ logger(f"Error during file upload: {str(e)}")
130
  return {"message": f"Error during file upload: {str(e)}"}
131
 
132
  if file.filename.endswith('.pdf'):
 
134
  # let's save the file in /data even if it's temp storage on HF
135
  with open(filepath, 'wb') as f:
136
  f.write(contents)
137
+
138
+ # save it also in assets/data because data can be cleared
139
+ filepath2 = os.path.join(datadir2, file.filename)
140
+ with open(filepath2, 'wb') as f:
141
+ f.write(contents)
142
 
143
  try:
144
+ logger(f"Starting to process {file.filename}")
145
  new_content = process_pdf(filepath)
146
  success = {"message": f"Successfully uploaded {file.filename}"}
147
  success.update(new_content)
 
149
 
150
  except Exception as e:
151
  return {"message": f"Failed to extract text from PDF: {str(e)}"}
152
+
153
+ elif file.filename.endswith('.txt'):
154
+
155
+ with open(filepath, 'wb') as f:
156
+ f.write(contents)
157
+
158
+ filepath2 = os.path.join(datadir2, file.filename)
159
+ with open(filepath2, 'wb') as f:
160
+ f.write(contents)
161
+
162
+ try:
163
+ logger(f"Reading {file.filename}")
164
+ new_content = process_txt(filepath)
165
+ success = {"message": f"Successfully uploaded {file.filename}"}
166
+ success.update(new_content)
167
+ return success
168
+
169
+ except Exception as e:
170
+ return {"message": f"Failed to extract text from TXT: {str(e)}"}
171
+
172
  else:
173
+ return {"message": "Only PDF & txt files are accepted"}
174
 
175
 
176
  @app.post("/create_index/")
177
  async def create_index():
178
  """ Create an index for the uploaded files """
179
 
180
+ logger("Creating index for uploaded files")
181
  try:
182
  msg = index_data()
183
  return {"message": msg}
 
190
 
191
  @app.post("/ask/")
192
  async def hybrid_search(question: Question):
193
+ logger(f"Processing question: {question.question}")
194
  try:
195
  search_results = vector_search(question.question)
196
+ logger(f"Answer: {search_results}")
197
  return {"answer": search_results}
198
  except Exception as e:
199
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
200
 
201
+
202
  @app.post("/ragit/")
203
  async def ragit(question: Question):
204
+ logger(f"Processing question: {question.question}")
205
  try:
206
+ search_results = vector_search_raw(question.question)
207
+ logger(f"Search results generated: {search_results}")
208
 
209
  answer = rag_it(question.question, search_results)
210
 
211
+ logger(f"Answer: {answer}")
212
  return {"answer": answer}
213
  except Exception as e:
214
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
215
 
216
+
217
  if __name__ == '__main__':
218
  import uvicorn
219
  from os import getenv
 
223
  uvicorn.run("main:app", host="0.0.0.0", port=port, reload=reload)
224
 
225
 
 
226
  # Examples:
227
+ # curl -X POST "http://localhost:8001/upload" -F "file=@test.pdf"
228
+ # curl -X DELETE "http://localhost:8001/erase_data/"
229
+ # curl -X GET "http://localhost:8001/list_files/"
230
 
231
+ # hf space is at https://jpbianchi-multirag.hf.space/
232
+ # code given by https://jpbianchi-multirag.hf.space/docs
233
  # Space must be public
234
+ # curl -X POST "https://jpbianchi-multirag.hf.space/upload/" -F "file=@test.pdf"
235
 
236
  # curl -X POST http://localhost:80/ask/ -H "Content-Type: application/json" -d '{"question": "what is Amazon loss"}'
237
  # curl -X POST http://localhost:80/ragit/ -H "Content-Type: application/json" -d '{"question": "Does ATT have postpaid phone customers?"}'
238
+ # see more in notebook upload_index.ipynb
app/main_reflex.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # this is the original main.py file, but without the call to fastapi
2
+ # since it is done by reflex's own fast api server
3
+
4
+ import os, random, logging, pickle, shutil
5
+ from dotenv import load_dotenv, find_dotenv
6
+ from typing import Optional
7
+ from pydantic import BaseModel, Field
8
+
9
+ from fastapi import FastAPI, HTTPException, File, UploadFile, status
10
+ from fastapi.responses import HTMLResponse
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from app.engine.processing import ( # << creates the collection already
13
+ process_pdf,
14
+ process_txt,
15
+ index_data,
16
+ empty_collection,
17
+ vector_search,
18
+ vector_search_raw,
19
+ )
20
+ from app.rag.rag import rag_it
21
+
22
+ from app.engine.logger import logger
23
+
24
+ from app.settings import datadir, datadir2
25
+
26
+ EXTENSIONS = ["pdf", "txt"]
27
+
28
+ # app = FastAPI()
29
+
30
+ environment = os.getenv("ENVIRONMENT", "dev") # created by dockerfile
31
+
32
+ # replaced by cors_allowed_origins=['*'] in rxconfig.py when using Reflex endpoint
33
+ # if environment == "dev":
34
+ # logger("Running in development mode - allowing CORS for all origins")
35
+ # app.add_middleware(
36
+ # CORSMiddleware,
37
+ # allow_origins=["*"],
38
+ # allow_credentials=True,
39
+ # allow_methods=["*"],
40
+ # allow_headers=["*"],
41
+ # )
42
+
43
+
44
+ # not used when using Reflex endpoint
45
+ # @app.get("/", response_class=HTMLResponse)
46
+ def read_root():
47
+ logger("Title displayed on home page")
48
+ return """
49
+ <html>
50
+ <body>
51
+ <h1>Welcome to MultiRAG, a RAG system designed by JP Bianchi!</h1>
52
+ </body>
53
+ </html>
54
+ """
55
+
56
+ # already provided by Reflex
57
+ # @app.get("/ping/")
58
+ def ping():
59
+ """ Testing """
60
+ logger("Someone is pinging the server")
61
+ return {"answer": str(int(random.random() * 100))}
62
+
63
+
64
+ # @app.delete("/erase_data/")
65
+ def erase_data():
66
+ """ Erase all files in the data directory at the first level only,
67
+ (in case we would like to use it for something else)
68
+ but not the vector store or the parquet file.
69
+ We can do it since the embeddings are in the parquet file already.
70
+ """
71
+ if len(os.listdir(datadir)) == 0:
72
+ logger("No data to erase")
73
+ return {"message": "No data to erase"}
74
+
75
+ # if we try to rmtree datadir, it looks like /data can't be deleted on HF
76
+ for f in os.listdir(datadir):
77
+ if f == '.DS_Store' or f.split('.')[-1].lower() in EXTENSIONS:
78
+ print(f"Removing {f}")
79
+ os.remove(os.path.join(datadir, f))
80
+ # we don't remove the parquet file, create_index does that
81
+
82
+ logger("All data has been erased")
83
+ return {"message": "All data has been erased"}
84
+
85
+
86
+ # @app.delete("/empty_collection/")
87
+ def delete_vectors():
88
+ """ Empty the collection in the vector store """
89
+ try:
90
+ status = empty_collection()
91
+ return {"message": f"Collection{'' if status else ' NOT'} erased!"}
92
+ except Exception as e:
93
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
94
+
95
+
96
+ # @app.get("/list_files/")
97
+ def list_files():
98
+ """ List all files in the data directory """
99
+ print("Listing files")
100
+ files = os.listdir(datadir)
101
+ logger(f"Files in data directory: {files}")
102
+ return {"files": files}
103
+
104
+
105
+ # @app.post("/upload/")
106
+ # @limiter.limit("5/minute") see 'slowapi' for rate limiting
107
+ async def upload_file(file: UploadFile = File(...)):
108
+ """ Uploads a file in data directory, for later indexing """
109
+ try:
110
+ filepath = os.path.join(datadir, file.filename)
111
+ logger(f"Fiename detected: {file.filename}")
112
+ if os.path.exists(filepath):
113
+ logger(f"File {file.filename} already exists: no processing done")
114
+ return {"message": f"File {file.filename} already exists: no processing done"}
115
+
116
+ else:
117
+ logger(f"Receiving file: {file.filename}")
118
+ contents = await file.read()
119
+ logger(f"File reception complete!")
120
+
121
+ except Exception as e:
122
+ logger(f"Error during file upload: {str(e)}")
123
+ return {"message": f"Error during file upload: {str(e)}"}
124
+
125
+ if file.filename.endswith('.pdf'):
126
+
127
+ # let's save the file in /data even if it's temp storage on HF
128
+ with open(filepath, 'wb') as f:
129
+ f.write(contents)
130
+
131
+ # save it also in assets/data because data can be cleared
132
+ filepath2 = os.path.join(datadir2, file.filename)
133
+ with open(filepath2, 'wb') as f:
134
+ f.write(contents)
135
+
136
+ try:
137
+ logger(f"Starting to process {file.filename}")
138
+ new_content = process_pdf(filepath)
139
+ success = {"message": f"Successfully uploaded {file.filename}"}
140
+ success.update(new_content)
141
+ return success
142
+
143
+ except Exception as e:
144
+ return {"message": f"Failed to extract text from PDF: {str(e)}"}
145
+
146
+ elif file.filename.endswith('.txt'):
147
+
148
+ with open(filepath, 'wb') as f:
149
+ f.write(contents)
150
+
151
+ filepath2 = os.path.join(datadir2, file.filename)
152
+ with open(filepath2, 'wb') as f:
153
+ f.write(contents)
154
+
155
+ try:
156
+ logger(f"Reading {file.filename}")
157
+ new_content = process_txt(filepath)
158
+ success = {"message": f"Successfully uploaded {file.filename}"}
159
+ success.update(new_content)
160
+ return success
161
+
162
+ except Exception as e:
163
+ return {"message": f"Failed to extract text from TXT: {str(e)}"}
164
+
165
+ else:
166
+ return {"message": "Only PDF & txt files are accepted"}
167
+
168
+
169
+ # @app.post("/create_index/")
170
+ async def create_index():
171
+ """ Create an index for the uploaded files """
172
+
173
+ logger("Creating index for uploaded files")
174
+ try:
175
+ msg = index_data()
176
+ return {"message": msg}
177
+ except Exception as e:
178
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
179
+
180
+
181
+ class Question(BaseModel):
182
+ question: str
183
+
184
+ # @app.post("/ask/")
185
+ async def hybrid_search(question: Question):
186
+ logger(f"Processing question: {question.question}")
187
+ try:
188
+ search_results = vector_search(question.question)
189
+ logger(f"Answer: {search_results}")
190
+ return {"answer": search_results}
191
+ except Exception as e:
192
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
193
+
194
+
195
+ # @app.post("/ragit/")
196
+ async def ragit(question: Question):
197
+ logger(f"Processing question: {question.question}")
198
+ try:
199
+ search_results = vector_search_raw(question.question)
200
+ logger(f"Search results generated: {search_results}")
201
+
202
+ answer = rag_it(question.question, search_results)
203
+
204
+ logger(f"Answer: {answer}")
205
+ return {"answer": answer}
206
+ except Exception as e:
207
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
208
+
209
+
210
+ if __name__ == '__main__':
211
+ import uvicorn
212
+ from os import getenv
213
+ port = int(getenv("PORT", 80))
214
+ print(f"Starting server on port {port}")
215
+ reload = True if environment == "dev" else False
216
+ uvicorn.run("main:app", host="0.0.0.0", port=port, reload=reload)
217
+
218
+
219
+ # Examples:
220
+ # curl -X POST "http://localhost:8001/upload" -F "file=@test.pdf"
221
+ # curl -X DELETE "http://localhost:8001/erase_data/"
222
+ # curl -X GET "http://localhost:8001/list_files/"
223
+
224
+ # hf space is at https://jpbianchi-multirag.hf.space/
225
+ # code given by https://jpbianchi-multirag.hf.space/docs
226
+ # Space must be public
227
+ # curl -X POST "https://jpbianchi-multirag.hf.space/upload/" -F "file=@test.pdf"
228
+
229
+ # curl -X POST http://localhost:80/ask/ -H "Content-Type: application/json" -d '{"question": "what is Amazon loss"}'
230
+ # curl -X POST http://localhost:80/ragit/ -H "Content-Type: application/json" -d '{"question": "Does ATT have postpaid phone customers?"}'
231
+ # see more in notebook upload_index.ipynb
app/notebooks/__init__.py ADDED
File without changes
app/requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
  requests==2.31.0
2
  pydantic==2.7.1
3
  pydantic_core==2.18.2
4
- fastapi
5
- uvicorn[standard]
6
  pdfplumber==0.11.0
7
  weaviate-client==4.5.4
8
  PyPDF2==3.0.1
@@ -21,4 +21,8 @@ langchain-community==0.0.38
21
  langchain-core==0.1.52
22
  langchain-text-splitters==0.0.1
23
  python-multipart==0.0.9
24
- tenacity==8.2.3
 
 
 
 
 
1
  requests==2.31.0
2
  pydantic==2.7.1
3
  pydantic_core==2.18.2
4
+ fastapi==0.110.0
5
+ uvicorn==0.20.0
6
  pdfplumber==0.11.0
7
  weaviate-client==4.5.4
8
  PyPDF2==3.0.1
 
21
  langchain-core==0.1.52
22
  langchain-text-splitters==0.0.1
23
  python-multipart==0.0.9
24
+ tenacity==8.2.3
25
+ typer
26
+ # https://hub.guardrailsai.com/tokens
27
+ guardrails-ai<=0.4.2 # API KEY doesn not work above that version
28
+ loguru==0.7.2 # used in reranker
app/settings.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
2
 
3
- datadir = '../data' # will be used in main.py
 
4
  parquet_file = os.path.join(datadir, 'text_vectors.parquet') # used by the files in 'engine'
 
1
  import os
2
 
3
+ datadir = 'data' # will be used in main.py
4
+ datadir2 = 'assets/data' # backup since data can be emptied
5
  parquet_file = os.path.join(datadir, 'text_vectors.parquet') # used by the files in 'engine'
app/tests/test_main.py CHANGED
@@ -4,7 +4,7 @@ from main import app
4
 
5
  from fastapi.testclient import TestClient
6
 
7
- from settings import datadir
8
 
9
  client = TestClient(app)
10
 
 
4
 
5
  from fastapi.testclient import TestClient
6
 
7
+ from app.settings import datadir
8
 
9
  client = TestClient(app)
10
 
assets/IO_logo.webp ADDED
assets/OI_logo.jpg ADDED
assets/amazon_forecast.jpg ADDED
assets/amazon_idiot.jpg ADDED
assets/favicon.ico ADDED
assets/homepage.jpg ADDED
assets/irrelevant_amazon.jpg ADDED
assets/multirag_good.jpeg ADDED