mitulagr2 commited on
Commit
4c88907
1 Parent(s): 82c3144

Update rag.py

Browse files
Files changed (2) hide show
  1. app/main.py +7 -3
  2. app/rag.py +7 -1
app/main.py CHANGED
@@ -66,9 +66,7 @@ async def websocket_endpoint(websocket: WebSocket, client_id: int):
66
 
67
 
68
  @app.post("/upload")
69
- def upload(files: list[UploadFile]):
70
- session_assistant.clear()
71
-
72
  try:
73
  os.makedirs(files_dir)
74
  for file in files:
@@ -86,6 +84,12 @@ def upload(files: list[UploadFile]):
86
  return "Files inserted!"
87
 
88
 
 
 
 
 
 
 
89
  @app.get("/")
90
  def ping():
91
  return "Pong!"
 
66
 
67
 
68
  @app.post("/upload")
69
+ def upload(files: list[UploadFile]):
 
 
70
  try:
71
  os.makedirs(files_dir)
72
  for file in files:
 
84
  return "Files inserted!"
85
 
86
 
87
+ @app.get("/clear")
88
+ def ping():
89
+ session_assistant.clear()
90
+ return "All files have been cleared."
91
+
92
+
93
  @app.get("/")
94
  def ping():
95
  return "Pong!"
app/rag.py CHANGED
@@ -22,6 +22,8 @@ logging.basicConfig(level=logging.INFO)
22
  logger = logging.getLogger(__name__)
23
 
24
  class ChatPDF:
 
 
25
  def __init__(self):
26
  self.text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=24)
27
 
@@ -70,6 +72,7 @@ class ChatPDF:
70
 
71
  logger.info("enumerating docs")
72
  for doc_idx, doc in enumerate(docs):
 
73
  curr_text_chunks = self.text_parser.split_text(doc.text)
74
  text_chunks.extend(curr_text_chunks)
75
  doc_ids.extend([doc_idx] * len(curr_text_chunks))
@@ -104,8 +107,11 @@ class ChatPDF:
104
 
105
  def ask(self, query: str):
106
  logger.info("retrieving the response to the query")
 
 
 
107
  streaming_response = self.query_engine.query(query)
108
  return streaming_response.response_gen
109
 
110
  def clear(self):
111
- pass
 
22
  logger = logging.getLogger(__name__)
23
 
24
  class ChatPDF:
25
+ pdf_count = 0
26
+
27
  def __init__(self):
28
  self.text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=24)
29
 
 
72
 
73
  logger.info("enumerating docs")
74
  for doc_idx, doc in enumerate(docs):
75
+ self.pdf_count = self.pdf_count + 1
76
  curr_text_chunks = self.text_parser.split_text(doc.text)
77
  text_chunks.extend(curr_text_chunks)
78
  doc_ids.extend([doc_idx] * len(curr_text_chunks))
 
107
 
108
  def ask(self, query: str):
109
  logger.info("retrieving the response to the query")
110
+ if not self.pdf_count > 0:
111
+ return "Please, add a PDF document first."
112
+
113
  streaming_response = self.query_engine.query(query)
114
  return streaming_response.response_gen
115
 
116
  def clear(self):
117
+ self.pdf_count = 0