Daniel Foley committed
Commit 48ae0fa · Parent: ae94932

added n-worker parallelism to metadata fetching

Files changed (1):
  1. RAG.py +40 -72
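
In short, the commit swaps the sequential per-document metadata fetch inside rerank() for a concurrent.futures thread pool. A minimal sketch of that pattern, for orientation only (the item IDs and worker count are placeholders, not values from this repo):

    # Minimal sketch of the thread-pool fetch pattern this commit introduces.
    # The item IDs and worker count below are placeholders, not values from the repo.
    import concurrent.futures
    import requests
    from typing import Dict, Optional

    def fetch_metadata(item_id: str) -> Optional[Dict]:
        """Fetch one Digital Commonwealth JSON record, returning None on any HTTP error."""
        try:
            response = requests.get(
                f"https://www.digitalcommonwealth.org/search/{item_id}.json",
                timeout=10,
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException:
            return None

    item_ids = ["commonwealth:example1", "commonwealth:example2"]  # placeholder IDs
    records = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        # Submit one fetch task per item, then collect results as they finish
        futures = {executor.submit(fetch_metadata, item_id): item_id for item_id in item_ids}
        for future in concurrent.futures.as_completed(futures):
            record = future.result()  # re-raises any exception from the worker thread
            if record is not None:
                records.append(record)

Threads fit this workload because each task spends most of its life blocked on an HTTP request; the interpreter releases the GIL while waiting on the socket, so up to max_workers fetches can be in flight at once.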
RAG.py CHANGED
@@ -12,49 +12,8 @@ from langchain_core.documents import Document
 from langchain_community.retrievers import BM25Retriever
 import requests
 from typing import Dict, Any, Optional, List, Tuple
-import json
 import logging
-
-
-import logging
-from datetime import datetime
-from io import StringIO
-
-class RunLogger:
-    def __init__(self, script_name='streamlit_script'):
-        # Create string buffer to store logs
-        self.log_buffer = StringIO()
-
-        # Create logger
-        self.logger = logging.getLogger(script_name)
-        self.logger.setLevel(logging.INFO)
-
-        # Create handler that writes to our string buffer
-        handler = logging.StreamHandler(self.log_buffer)
-        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
-        handler.setFormatter(formatter)
-        self.logger.addHandler(handler)
-
-        self.logger.info("=== Starting new run ===")
-
-    def info(self, message):
-        self.logger.info(message)
-
-    def error(self, message):
-        self.logger.error(message)
-
-    def warning(self, message):
-        self.logger.warning(message)
-
-    def output_logs(self):
-        """Print all collected logs"""
-        print("\n=== Run Complete - All Logs ===")
-        print(self.log_buffer.getvalue())
-        print("=== End Logs ===\n")
-
-    def __del__(self):
-        """Ensure logs are output if logger is garbage collected"""
-        self.output_logs()
+import concurrent.futures

 def retrieve(query: str,vectorstore:PineconeVectorStore, k: int = 100) -> Tuple[List[Document], List[float]]:
     start = time.time()
@@ -80,7 +39,6 @@ def retrieve(query: str,vectorstore:PineconeVectorStore, k: int = 100) -> Tuple[

 def safe_get_json(url: str) -> Optional[Dict]:
     """Safely fetch and parse JSON from a URL."""
-    print("Fetching JSON")
     try:
         response = requests.get(url, timeout=10)
         response.raise_for_status()
@@ -89,42 +47,52 @@ def safe_get_json(url: str) -> Optional[Dict]:
         logging.error(f"Error fetching from {url}: {str(e)}")
         return None

-def extract_text_from_json(json_data: Dict) -> str:
-    """Extract text content from JSON response."""
-    if not json_data:
-        return ""
-
-    text_parts = []
-
-    # Handle direct text fields
-    text_fields = ["title_info_primary_tsi","abstract_tsi","subject_geographic_sim","genre_basic_ssim","genre_specific_ssim","date_tsim"]
-    for field in text_fields:
-        if field in json_data['data']['attributes'] and json_data['data']['attributes'][field]:
-            # print(json_data[field])
-            text_parts.append(str(json_data['data']['attributes'][field]))
-
-    return " ".join(text_parts) if text_parts else "No content available"
-
-def rerank(documents: List[Document], query: str) -> List[Document]:
-    """Ingest more metadata. Rerank documents using BM25"""
+def process_single_document(doc: Document) -> Optional[Document]:
+    """Process a single document by fetching and extracting metadata."""
+    if not doc.metadata.get('source'):
+        return None
+
+    url = f"https://www.digitalcommonwealth.org/search/{doc.metadata['source']}"
+    json_data = safe_get_json(f"{url}.json")
+
+    if json_data:
+        text_content = extract_text_from_json(json_data)
+        if text_content:
+            return Document(
+                page_content=text_content,
+                metadata={
+                    "source": doc.metadata['source'],
+                    "field": doc.metadata['field'],
+                    "URL": url
+                }
+            )
+    return None
+
+def rerank(documents: List[Document], query: str, max_workers: int = 2) -> List[Document]:
+    """Ingest more metadata and rerank documents using BM25 with parallel processing."""
     start = time.time()
     if not documents:
         return []

-    full_docs = []
     meta_start = time.time()
-    for doc in documents:
-        if not doc.metadata.get('source'):
-            continue
-
-        url = f"https://www.digitalcommonwealth.org/search/{doc.metadata['source']}"
-        json_data = safe_get_json(f"{url}.json")
+    full_docs = []
+
+    # Process documents in parallel using ThreadPoolExecutor
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Submit all document processing tasks
+        future_to_doc = {
+            executor.submit(process_single_document, doc): doc
+            for doc in documents
+        }

-        if json_data:
-            text_content = extract_text_from_json(json_data)
-            if text_content:  # Only add documents with actual content
-                full_docs.append(Document(page_content=text_content, metadata={"source":doc.metadata['source'],"field":doc.metadata['field'],"URL":url}))
+        # Collect results as they complete
+        for future in concurrent.futures.as_completed(future_to_doc):
+            processed_doc = future.result()
+            if processed_doc:
+                full_docs.append(processed_doc)
+
     logging.info(f"Took {time.time()-meta_start} seconds to retrieve all metadata")
+
     # If no valid documents were processed, return empty list
     if not full_docs:
         return []
@@ -133,7 +101,7 @@ def rerank(documents: List[Document], query: str) -> List[Document]:
     reranker = BM25Retriever.from_documents(full_docs, k=min(10, len(full_docs)))
     reranked_docs = reranker.invoke(query)
     logging.info(f"Finished reranking: {time.time()-start}")
-    return reranked_docs
+    return full_docs

 def parse_xml_and_query(query:str,xml_string:str) -> str:
     """parse xml and return rephrased query"""