oceansweep committed
Commit 5fab6ba
1 Parent(s): a6ecdfa

Update App_Function_Libraries/RAG/RAG_Libary_2.py

App_Function_Libraries/RAG/RAG_Libary_2.py CHANGED
@@ -1,332 +1,153 @@
- # RAG_Library_2.py
- # Description: This script contains the main RAG pipeline function and related functions for the RAG pipeline.
- #
- # Import necessary modules and functions
- import configparser
- import logging
- import os
- from typing import Dict, Any, List, Optional
- # Local Imports
- from App_Function_Libraries.RAG.ChromaDB_Library import process_and_store_content, vector_search, chroma_client
- from App_Function_Libraries.Article_Extractor_Lib import scrape_article
- from App_Function_Libraries.DB.DB_Manager import add_media_to_database, search_db, get_unprocessed_media, \
-     fetch_keywords_for_media
- from App_Function_Libraries.Utils.Utils import load_comprehensive_config
- #
- # 3rd-Party Imports
- import openai
- #
- ########################################################################################################################
- #
- # Functions:
-
- # Initialize OpenAI client (adjust this based on your API key management)
- openai.api_key = "your-openai-api-key"
-
- # Get the directory of the current script
- current_dir = os.path.dirname(os.path.abspath(__file__))
- # Construct the path to the config file
- config_path = os.path.join(current_dir, 'Config_Files', 'config.txt')
- # Read the config file
- config = configparser.ConfigParser()
- # Read the configuration file
- config.read('config.txt')
-
- # Main RAG pipeline function
- def rag_pipeline(url: str, query: str, api_choice=None) -> Dict[str, Any]:
-     try:
-         # Extract content
-         try:
-             article_data = scrape_article(url)
-             content = article_data['content']
-             title = article_data['title']
-         except Exception as e:
-             logging.error(f"Error scraping article: {str(e)}")
-             return {"error": "Failed to scrape article", "details": str(e)}
-
-         # Store the article in the database and get the media_id
-         try:
-             media_id = add_media_to_database(url, title, 'article', content)
-         except Exception as e:
-             logging.error(f"Error adding article to database: {str(e)}")
-             return {"error": "Failed to store article in database", "details": str(e)}
-
-         # Process and store content
-         collection_name = f"article_{media_id}"
-         try:
-             process_and_store_content(content, collection_name, media_id, title)
-         except Exception as e:
-             logging.error(f"Error processing and storing content: {str(e)}")
-             return {"error": "Failed to process and store content", "details": str(e)}
-
-         # Perform searches
-         try:
-             vector_results = vector_search(collection_name, query, k=5)
-             fts_results = search_db(query, ["content"], "", page=1, results_per_page=5)
-         except Exception as e:
-             logging.error(f"Error performing searches: {str(e)}")
-             return {"error": "Failed to perform searches", "details": str(e)}
-
-         # Combine results with error handling for missing 'content' key
-         all_results = []
-         for result in vector_results + fts_results:
-             if isinstance(result, dict) and 'content' in result:
-                 all_results.append(result['content'])
-             else:
-                 logging.warning(f"Unexpected result format: {result}")
-                 all_results.append(str(result))
-
-         context = "\n".join(all_results)
-
-         # Generate answer using the selected API
-         try:
-             answer = generate_answer(api_choice, context, query)
-         except Exception as e:
-             logging.error(f"Error generating answer: {str(e)}")
-             return {"error": "Failed to generate answer", "details": str(e)}
-
-         return {
-             "answer": answer,
-             "context": context
-         }
-
-     except Exception as e:
-         logging.error(f"Unexpected error in rag_pipeline: {str(e)}")
-         return {"error": "An unexpected error occurred", "details": str(e)}
-
-
-
- # RAG Search with keyword filtering
- def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None) -> Dict[str, Any]:
-     try:
-         # Load embedding provider from config, or fallback to 'openai'
-         embedding_provider = config.get('Embeddings', 'provider', fallback='openai')
-
-         # Log the provider used
-         logging.debug(f"Using embedding provider: {embedding_provider}")
-
-         # Process keywords if provided
-         keyword_list = [k.strip().lower() for k in keywords.split(',')] if keywords else []
-         logging.debug(f"enhanced_rag_pipeline - Keywords: {keyword_list}")
-
-         # Fetch relevant media IDs based on keywords if keywords are provided
-         relevant_media_ids = fetch_relevant_media_ids(keyword_list) if keyword_list else None
-         logging.debug(f"enhanced_rag_pipeline - relevant media IDs: {relevant_media_ids}")
-
-         # Perform vector search
-         vector_results = perform_vector_search(query, relevant_media_ids)
-         logging.debug(f"enhanced_rag_pipeline - Vector search results: {vector_results}")
-
-         # Perform full-text search
-         fts_results = perform_full_text_search(query, relevant_media_ids)
-         logging.debug(f"enhanced_rag_pipeline - Full-text search results: {fts_results}")
-
-         # Combine results
-         all_results = vector_results + fts_results
-         # FIXME
-         if not all_results:
-             logging.info(f"No results found. Query: {query}, Keywords: {keywords}")
-             return {
-                 "answer": "I couldn't find any relevant information based on your query and keywords.",
-                 "context": ""
-             }
-
-         # FIXME - Apply Re-Ranking of results here
-         apply_re_ranking = False
-         if apply_re_ranking:
-             # Implement re-ranking logic here
-             pass
-         # Extract content from results
-         context = "\n".join([result['content'] for result in all_results[:10]])  # Limit to top 10 results
-         logging.debug(f"Context length: {len(context)}")
-         logging.debug(f"Context: {context[:200]}")
-         # Generate answer using the selected API
-         answer = generate_answer(api_choice, context, query)
-
-         return {
-             "answer": answer,
-             "context": context
-         }
-     except Exception as e:
-         logging.error(f"Error in enhanced_rag_pipeline: {str(e)}")
-         return {
-             "answer": "An error occurred while processing your request.",
-             "context": ""
-         }
-
-
- def generate_answer(api_choice: str, context: str, query: str) -> str:
-     logging.debug("Entering generate_answer function")
-     config = load_comprehensive_config()
-     logging.debug(f"Config sections: {config.sections()}")
-     prompt = f"Context: {context}\n\nQuestion: {query}"
-     if api_choice == "OpenAI":
-         from App_Function_Libraries.Summarization_General_Lib import summarize_with_openai
-         return summarize_with_openai(config['API']['openai_api_key'], prompt, "")
-     elif api_choice == "Anthropic":
-         from App_Function_Libraries.Summarization_General_Lib import summarize_with_anthropic
-         return summarize_with_anthropic(config['API']['anthropic_api_key'], prompt, "")
-     elif api_choice == "Cohere":
-         from App_Function_Libraries.Summarization_General_Lib import summarize_with_cohere
-         return summarize_with_cohere(config['API']['cohere_api_key'], prompt, "")
-     elif api_choice == "Groq":
-         from App_Function_Libraries.Summarization_General_Lib import summarize_with_groq
-         return summarize_with_groq(config['API']['groq_api_key'], prompt, "")
-     elif api_choice == "OpenRouter":
-         from App_Function_Libraries.Summarization_General_Lib import summarize_with_openrouter
-         return summarize_with_openrouter(config['API']['openrouter_api_key'], prompt, "")
-     elif api_choice == "HuggingFace":
-         from App_Function_Libraries.Summarization_General_Lib import summarize_with_huggingface
-         return summarize_with_huggingface(config['API']['huggingface_api_key'], prompt, "")
-     elif api_choice == "DeepSeek":
-         from App_Function_Libraries.Summarization_General_Lib import summarize_with_deepseek
-         return summarize_with_deepseek(config['API']['deepseek_api_key'], prompt, "")
-     elif api_choice == "Mistral":
-         from App_Function_Libraries.Summarization_General_Lib import summarize_with_mistral
-         return summarize_with_mistral(config['API']['mistral_api_key'], prompt, "")
-     elif api_choice == "Local-LLM":
-         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_local_llm
-         return summarize_with_local_llm(config['API']['local_llm_path'], prompt, "")
-     elif api_choice == "Llama.cpp":
-         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_llama
-         return summarize_with_llama(config['API']['llama_api_key'], prompt, "")
-     elif api_choice == "Kobold":
-         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_kobold
-         return summarize_with_kobold(config['API']['kobold_api_key'], prompt, "")
-     elif api_choice == "Ooba":
-         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_oobabooga
-         return summarize_with_oobabooga(config['API']['ooba_api_key'], prompt, "")
-     elif api_choice == "TabbyAPI":
-         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_tabbyapi
-         return summarize_with_tabbyapi(config['API']['tabby_api_key'], prompt, "")
-     elif api_choice == "vLLM":
-         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_vllm
-         return summarize_with_vllm(config['API']['vllm_api_key'], prompt, "")
-     elif api_choice == "ollama":
-         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_ollama
-         return summarize_with_ollama(config['API']['ollama_api_key'], prompt, "")
-     else:
-         raise ValueError(f"Unsupported API choice: {api_choice}")
-
- # Function to preprocess and store all existing content in the database
- def preprocess_all_content():
-     unprocessed_media = get_unprocessed_media()
-     for row in unprocessed_media:
-         media_id = row[0]
-         content = row[1]
-         media_type = row[2]
-         collection_name = f"{media_type}_{media_id}"
-         process_and_store_content(content, collection_name, media_id, "")
-
-
- def perform_vector_search(query: str, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]:
-     all_collections = chroma_client.list_collections()
-     vector_results = []
-     for collection in all_collections:
-         collection_results = vector_search(collection.name, query, k=5)
-         filtered_results = [
-             result for result in collection_results
-             if relevant_media_ids is None or result['metadata'].get('media_id') in relevant_media_ids
-         ]
-         vector_results.extend(filtered_results)
-     return vector_results
-
-
- def perform_full_text_search(query: str, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]:
-     fts_results = search_db(query, ["content"], "", page=1, results_per_page=5)
-     filtered_fts_results = [
-         {
-             "content": result['content'],
-             "metadata": {"media_id": result['id']}
-         }
-         for result in fts_results
-         if relevant_media_ids is None or result['id'] in relevant_media_ids
-     ]
-     return filtered_fts_results
-
-
- def fetch_relevant_media_ids(keywords: List[str]) -> List[int]:
-     relevant_ids = set()
-     try:
-         for keyword in keywords:
-             media_ids = fetch_keywords_for_media(keyword)
-             relevant_ids.update(media_ids)
-     except Exception as e:
-         logging.error(f"Error fetching relevant media IDs: {str(e)}")
-     return list(relevant_ids)
-
-
- def filter_results_by_keywords(results: List[Dict[str, Any]], keywords: List[str]) -> List[Dict[str, Any]]:
-     if not keywords:
-         return results
-
-     filtered_results = []
-     for result in results:
-         try:
-             metadata = result.get('metadata', {})
-             if metadata is None:
-                 logging.warning(f"No metadata found for result: {result}")
-                 continue
-             if not isinstance(metadata, dict):
-                 logging.warning(f"Unexpected metadata type: {type(metadata)}. Expected dict.")
-                 continue
-
-             media_id = metadata.get('media_id')
-             if media_id is None:
-                 logging.warning(f"No media_id found in metadata: {metadata}")
-                 continue
-
-             media_keywords = fetch_keywords_for_media(media_id)
-             if any(keyword.lower() in [mk.lower() for mk in media_keywords] for keyword in keywords):
-                 filtered_results.append(result)
-         except Exception as e:
-             logging.error(f"Error processing result: {result}. Error: {str(e)}")
-
-     return filtered_results
-
- # FIXME: to be implemented
- def extract_media_id_from_result(result: str) -> Optional[int]:
-     # Implement this function based on how you store the media_id in your results
-     # For example, if it's stored at the beginning of each result:
-     try:
-         return int(result.split('_')[0])
-     except (IndexError, ValueError):
-         logging.error(f"Failed to extract media_id from result: {result}")
-         return None
-
-
-
-
- # Example usage:
- # 1. Initialize the system:
- # create_tables(db)  # Ensure FTS tables are set up
- #
- # 2. Create ChromaDB
- # chroma_client = ChromaDBClient()
- #
- # 3. Create Embeddings
- # Store embeddings in ChromaDB
- # preprocess_all_content() or create_embeddings()
- #
- # 4. Perform RAG search across all content:
- # result = rag_search("What are the key points about climate change?")
- # print(result['answer'])
- #
- # (Extra)5. Perform RAG on a specific URL:
- # result = rag_pipeline("https://example.com/article", "What is the main topic of this article?")
- # print(result['answer'])
- #
- ########################################################################################################################
-
-
- ############################################################################################################
- #
- # ElasticSearch Retriever
-
- # https://github.com/langchain-ai/langchain/tree/44e3e2391c48bfd0a8e6a20adde0b6567f4f43c3/templates/rag-elasticsearch
- #
- # https://github.com/langchain-ai/langchain/tree/44e3e2391c48bfd0a8e6a20adde0b6567f4f43c3/templates/rag-self-query
-
- #
- # End of RAG_Library_2.py
- ############################################################################################################
 
+ # RAG_Library_2.py
+ # Description: This script contains the main RAG pipeline function and related functions for the RAG pipeline.
+ #
+ # Import necessary modules and functions
+ import configparser
+ import logging
+ import os
+ from typing import Dict, Any, List, Optional
+ # Local Imports
+ #from App_Function_Libraries.RAG.ChromaDB_Library import process_and_store_content, vector_search, chroma_client
+ from App_Function_Libraries.Article_Extractor_Lib import scrape_article
+ from App_Function_Libraries.DB.DB_Manager import add_media_to_database, search_db, get_unprocessed_media, \
+     fetch_keywords_for_media
+ from App_Function_Libraries.Utils.Utils import load_comprehensive_config
+ #
+ # 3rd-Party Imports
+ import openai
+ #
+ ########################################################################################################################
+ #
+ # Functions:
+
+ # Initialize OpenAI client (adjust this based on your API key management)
+ openai.api_key = "your-openai-api-key"
+
+ # Get the directory of the current script
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ # Construct the path to the config file
+ config_path = os.path.join(current_dir, 'Config_Files', 'config.txt')
+ # Read the configuration file from the constructed path
+ config = configparser.ConfigParser()
+ config.read(config_path)
+
+
+ def generate_answer(api_choice: str, context: str, query: str) -> str:
+     logging.debug("Entering generate_answer function")
+     config = load_comprehensive_config()
+     logging.debug(f"Config sections: {config.sections()}")
+     prompt = f"Context: {context}\n\nQuestion: {query}"
+     if api_choice == "OpenAI":
+         from App_Function_Libraries.Summarization_General_Lib import summarize_with_openai
+         return summarize_with_openai(config['API']['openai_api_key'], prompt, "")
+     elif api_choice == "Anthropic":
+         from App_Function_Libraries.Summarization_General_Lib import summarize_with_anthropic
+         return summarize_with_anthropic(config['API']['anthropic_api_key'], prompt, "")
+     elif api_choice == "Cohere":
+         from App_Function_Libraries.Summarization_General_Lib import summarize_with_cohere
+         return summarize_with_cohere(config['API']['cohere_api_key'], prompt, "")
+     elif api_choice == "Groq":
+         from App_Function_Libraries.Summarization_General_Lib import summarize_with_groq
+         return summarize_with_groq(config['API']['groq_api_key'], prompt, "")
+     elif api_choice == "OpenRouter":
+         from App_Function_Libraries.Summarization_General_Lib import summarize_with_openrouter
+         return summarize_with_openrouter(config['API']['openrouter_api_key'], prompt, "")
+     elif api_choice == "HuggingFace":
+         from App_Function_Libraries.Summarization_General_Lib import summarize_with_huggingface
+         return summarize_with_huggingface(config['API']['huggingface_api_key'], prompt, "")
+     elif api_choice == "DeepSeek":
+         from App_Function_Libraries.Summarization_General_Lib import summarize_with_deepseek
+         return summarize_with_deepseek(config['API']['deepseek_api_key'], prompt, "")
+     elif api_choice == "Mistral":
+         from App_Function_Libraries.Summarization_General_Lib import summarize_with_mistral
+         return summarize_with_mistral(config['API']['mistral_api_key'], prompt, "")
+     elif api_choice == "Local-LLM":
+         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_local_llm
+         return summarize_with_local_llm(config['API']['local_llm_path'], prompt, "")
+     elif api_choice == "Llama.cpp":
+         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_llama
+         return summarize_with_llama(config['API']['llama_api_key'], prompt, "")
+     elif api_choice == "Kobold":
+         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_kobold
+         return summarize_with_kobold(config['API']['kobold_api_key'], prompt, "")
+     elif api_choice == "Ooba":
+         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_oobabooga
+         return summarize_with_oobabooga(config['API']['ooba_api_key'], prompt, "")
+     elif api_choice == "TabbyAPI":
+         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_tabbyapi
+         return summarize_with_tabbyapi(config['API']['tabby_api_key'], prompt, "")
+     elif api_choice == "vLLM":
+         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_vllm
+         return summarize_with_vllm(config['API']['vllm_api_key'], prompt, "")
+     elif api_choice == "ollama":
+         from App_Function_Libraries.Local_Summarization_Lib import summarize_with_ollama
+         return summarize_with_ollama(config['API']['ollama_api_key'], prompt, "")
+     else:
+         raise ValueError(f"Unsupported API choice: {api_choice}")
+
+
+ def perform_full_text_search(query: str, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]:
+     fts_results = search_db(query, ["content"], "", page=1, results_per_page=5)
+     filtered_fts_results = [
+         {
+             "content": result['content'],
+             "metadata": {"media_id": result['id']}
+         }
+         for result in fts_results
+         if relevant_media_ids is None or result['id'] in relevant_media_ids
+     ]
+     return filtered_fts_results
+
+
+ def fetch_relevant_media_ids(keywords: List[str]) -> List[int]:
+     relevant_ids = set()
+     try:
+         for keyword in keywords:
+             media_ids = fetch_keywords_for_media(keyword)
+             relevant_ids.update(media_ids)
+     except Exception as e:
+         logging.error(f"Error fetching relevant media IDs: {str(e)}")
+     return list(relevant_ids)
+
+
+ # Example usage:
+ # 1. Initialize the system:
+ # create_tables(db)  # Ensure FTS tables are set up
+ #
+ # 2. Create ChromaDB
+ # chroma_client = ChromaDBClient()
+ #
+ # 3. Create Embeddings
+ # Store embeddings in ChromaDB
+ # preprocess_all_content() or create_embeddings()
+ #
+ # 4. Perform RAG search across all content:
+ # result = rag_search("What are the key points about climate change?")
+ # print(result['answer'])
+ #
+ # (Extra)5. Perform RAG on a specific URL:
+ # result = rag_pipeline("https://example.com/article", "What is the main topic of this article?")
+ # print(result['answer'])
+ #
+ ########################################################################################################################
+
+
+ ############################################################################################################
+ #
+ # ElasticSearch Retriever
+
+ # https://github.com/langchain-ai/langchain/tree/44e3e2391c48bfd0a8e6a20adde0b6567f4f43c3/templates/rag-elasticsearch
+ #
+ # https://github.com/langchain-ai/langchain/tree/44e3e2391c48bfd0a8e6a20adde0b6567f4f43c3/templates/rag-self-query
+
+ #
+ # End of RAG_Library_2.py
+ ############################################################################################################
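
The retained generate_answer selects a backend through fifteen if/elif branches that differ only in module path, function name, and config key. Below is a minimal, illustrative sketch of a table-driven equivalent, not part of the commit: generate_answer_table_driven and API_DISPATCH are hypothetical names, and the sketch assumes the summarize_with_* functions keep the (api_key_or_path, prompt, system_prompt) call shape and the [API] config layout used in the file above, so it only runs where the App_Function_Libraries package is importable.

# Sketch (assumed names, same dispatch behavior as generate_answer above)
import importlib

from App_Function_Libraries.Utils.Utils import load_comprehensive_config

_GENERAL = "App_Function_Libraries.Summarization_General_Lib"
_LOCAL = "App_Function_Libraries.Local_Summarization_Lib"

# api_choice -> (module path, function name, key under the [API] config section)
API_DISPATCH = {
    "OpenAI":      (_GENERAL, "summarize_with_openai",      "openai_api_key"),
    "Anthropic":   (_GENERAL, "summarize_with_anthropic",   "anthropic_api_key"),
    "Cohere":      (_GENERAL, "summarize_with_cohere",      "cohere_api_key"),
    "Groq":        (_GENERAL, "summarize_with_groq",        "groq_api_key"),
    "OpenRouter":  (_GENERAL, "summarize_with_openrouter",  "openrouter_api_key"),
    "HuggingFace": (_GENERAL, "summarize_with_huggingface", "huggingface_api_key"),
    "DeepSeek":    (_GENERAL, "summarize_with_deepseek",    "deepseek_api_key"),
    "Mistral":     (_GENERAL, "summarize_with_mistral",     "mistral_api_key"),
    "Local-LLM":   (_LOCAL,   "summarize_with_local_llm",   "local_llm_path"),
    "Llama.cpp":   (_LOCAL,   "summarize_with_llama",       "llama_api_key"),
    "Kobold":      (_LOCAL,   "summarize_with_kobold",      "kobold_api_key"),
    "Ooba":        (_LOCAL,   "summarize_with_oobabooga",   "ooba_api_key"),
    "TabbyAPI":    (_LOCAL,   "summarize_with_tabbyapi",    "tabby_api_key"),
    "vLLM":        (_LOCAL,   "summarize_with_vllm",        "vllm_api_key"),
    "ollama":      (_LOCAL,   "summarize_with_ollama",      "ollama_api_key"),
}

def generate_answer_table_driven(api_choice: str, context: str, query: str) -> str:
    # Reject unknown providers, mirroring the else branch above
    if api_choice not in API_DISPATCH:
        raise ValueError(f"Unsupported API choice: {api_choice}")
    module_path, func_name, config_key = API_DISPATCH[api_choice]
    config = load_comprehensive_config()
    prompt = f"Context: {context}\n\nQuestion: {query}"
    # Import the backend module lazily, only when its provider is selected
    summarize = getattr(importlib.import_module(module_path), func_name)
    return summarize(config['API'][config_key], prompt, "")

# Example: answer = generate_answer_table_driven("OpenAI", context, query)

A dispatch table keeps the provider wiring in one place, so adding a backend is a one-line change, and the lazy importlib.import_module call preserves the original function's behavior of importing a backend module only when that provider is actually chosen.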