Asaad Almutareb committed on
Commit a0df48e
1 Parent(s): e0a73da

added websocket to hf_mixtral_agent

innovation_pathfinder_ai/backend/app/api/v1/agents/hf_mixtral_agent.py CHANGED

@@ -7,14 +7,14 @@ from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
 from langchain.tools.render import render_text_description
 import os
 from dotenv import load_dotenv
-from innovation_pathfinder_ai.backend.app.structured_tools.structured_tools import (
+from app.structured_tools.structured_tools import (
     arxiv_search, get_arxiv_paper, google_search, wikipedia_search, knowledgeBase_search, memory_search
 )
 from fastapi import APIRouter, WebSocket, WebSocketDisconnect
 from langchain.prompts import PromptTemplate
-from innovation_pathfinder_ai.backend.app.templates.react_json_with_memory import template_system
-from innovation_pathfinder_ai.backend.app.utils import logger
-from innovation_pathfinder_ai.backend.app.utils import generate_uuid
+from app.templates.react_json_with_memory import template_system
+from app.utils import logger
+from app.utils import utils
 from langchain.globals import set_llm_cache
 from langchain.cache import SQLiteCache
 
@@ -32,7 +32,7 @@ LANGCHAIN_PROJECT = os.getenv('LANGCHAIN_PROJECT')
 
 router = APIRouter()
 
-@router.websocket("")
+@router.websocket("/agent")
 async def websocket_endpoint(websocket: WebSocket):
     await websocket.accept()
 
@@ -41,16 +41,16 @@ async def websocket_endpoint(websocket: WebSocket):
         data = await websocket.receive_json()
         user_message = data["message"]
 
-        resp = IChatResponse(
-            sender="you",
-            message=user_message_card.to_dict(),
-            type="start",
-            message_id=generate_uuid(),
-            id=generate_uuid(),
-        )
+        # resp = IChatResponse(
+        #     sender="you",
+        #     message=user_message_card.to_dict(),
+        #     type="start",
+        #     message_id=generate_uuid(),
+        #     id=generate_uuid(),
+        # )
 
-        await websocket.send_json(resp.dict())
-        message_id: str = generate_uuid()
+        # await websocket.send_json(resp.dict())
+        message_id: str = utils.generate_uuid()
         # custom_handler = CustomFinalStreamingStdOutCallbackHandler(
         #     websocket, message_id=message_id
         # )
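Note on routing: the frontend change below connects to ws://localhost:8000/chat/agent, while this router only declares "/agent", so the endpoint presumably gets its /chat prefix when the router is included in the v1 API module (app/api/v1/api.py, imported by main.py but not shown in this commit). A minimal sketch of that wiring, with the prefix value being an assumption inferred from the frontend URI:

# app/api/v1/api.py -- hypothetical wiring, not part of this diff
from fastapi import APIRouter

from app.api.v1.agents import hf_mixtral_agent

api_router = APIRouter()

# Including the agent router under "/chat" exposes the WebSocket endpoint
# above at ws://<host>:8000/chat/agent, the URI the frontend connects to.
api_router.include_router(hf_mixtral_agent.router, prefix="/chat")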
innovation_pathfinder_ai/backend/app/api/v1/endpoints/add_to_kb.py CHANGED

@@ -1,8 +1,8 @@
 from fastapi import APIRouter
-from innovation_pathfinder_ai.backend.app.utils.utils import extract_urls
-from innovation_pathfinder_ai.backend.app.utils import logger
-from innovation_pathfinder_ai.backend.app.vector_store import initialize_chroma_db
-from innovation_pathfinder_ai.backend.app.utils.utils import (
+from app.utils.utils import extract_urls
+from app.utils import logger
+from app.vector_store import initialize_chroma_db
+from app.utils.utils import (
     generate_uuid
 )
 from langchain_community.vectorstores import Chroma
innovation_pathfinder_ai/backend/app/crud/db_handler.py CHANGED

@@ -1,6 +1,6 @@
 from sqlmodel import SQLModel, create_engine, Session, select
-from innovation_pathfinder_ai.backend.app.database.db_schema import Sources
-from innovation_pathfinder_ai.backend.app.utils.logger import get_console_logger
+from app.database.db_schema import Sources
+from app.utils.logger import get_console_logger
 import os
 from dotenv import load_dotenv
 
innovation_pathfinder_ai/backend/app/main.py CHANGED

@@ -1,5 +1,5 @@
 from fastapi import FastAPI
-from innovation_pathfinder_ai.backend.app.api.v1.api import api_router as api_router_v1
+from app.api.v1.api import api_router as api_router_v1
 #from app.core.config import settings
 from fastapi.middleware.cors import CORSMiddleware
 
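With the innovation_pathfinder_ai.backend prefix dropped from every import, the backend now has to be started with app as a top-level package, i.e. from the innovation_pathfinder_ai/backend/ directory. A minimal launcher sketch under that assumption (the file name and invocation are illustrative, not part of this commit):

# run_backend.py -- hypothetical launcher, placed in innovation_pathfinder_ai/backend/
# so that "app.main" resolves as a top-level import.
import uvicorn

if __name__ == "__main__":
    # Port 8000 matches the ws://localhost:8000/chat/agent URI used by the frontend.
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000)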
innovation_pathfinder_ai/backend/app/structured_tools/structured_tools.py CHANGED

@@ -15,24 +15,24 @@ import ast
 import chromadb
 
 # hacky and should be replaced with a database
-from innovation_pathfinder_ai.source_container.container import (
-    all_sources
-)
-from innovation_pathfinder_ai.utils.utils import (
+# from app.source_container.container import (
+#     all_sources
+# )
+from app.utils.utils import (
     parse_list_to_dicts, format_wiki_summaries, format_arxiv_documents, format_search_results
 )
-from backend.app.crud.db_handler import (
+from app.crud.db_handler import (
     add_many
 )
 
-from innovation_pathfinder_ai.vector_store.chroma_vector_store import (
+from app.vector_store.chroma_vector_store import (
     add_pdf_to_vector_store
 )
-from innovation_pathfinder_ai.utils.utils import (
+from app.utils.utils import (
     create_wikipedia_urls_from_text, create_folder_if_not_exists,
 )
 import os
-# from innovation_pathfinder_ai.utils import create_wikipedia_urls_from_text
+# from app.utils import create_wikipedia_urls_from_text
 
 persist_directory = os.getenv('VECTOR_DATABASE_LOCATION')
 
@@ -93,14 +93,14 @@ def knowledgeBase_search(query:str) -> str:
 def arxiv_search(query: str) -> str:
     """Search arxiv database for scientific research papers and studies. This is your primary online information source.
     always check it first when you search for additional information, before using any other online tool."""
-    global all_sources
+    #global all_sources
    arxiv_retriever = ArxivRetriever(load_max_docs=3)
     data = arxiv_retriever.invoke(query)
     meta_data = [i.metadata for i in data]
     formatted_sources = format_arxiv_documents(data)
-    all_sources += formatted_sources
+    #all_sources += formatted_sources
     parsed_sources = parse_list_to_dicts(formatted_sources)
-    add_many(parsed_sources)
+    #add_many(parsed_sources)
 
     return data.__str__()
 
@@ -162,28 +162,28 @@ def embed_arvix_paper(paper_id:str) -> None:
 @tool
 def wikipedia_search(query: str) -> str:
     """Search Wikipedia for additional information to expand on research papers or when no papers can be found."""
-    global all_sources
+    #global all_sources
 
     api_wrapper = WikipediaAPIWrapper()
     wikipedia_search = WikipediaQueryRun(api_wrapper=api_wrapper)
     wikipedia_results = wikipedia_search.run(query)
     formatted_summaries = format_wiki_summaries(wikipedia_results)
-    all_sources += formatted_summaries
+    #all_sources += formatted_summaries
     parsed_summaries = parse_list_to_dicts(formatted_summaries)
-    add_many(parsed_summaries)
+    #add_many(parsed_summaries)
     #all_sources += create_wikipedia_urls_from_text(wikipedia_results)
     return wikipedia_results
 
 @tool
 def google_search(query: str) -> str:
     """Search Google for additional results when you can't answer questions using arxiv search or wikipedia search."""
-    global all_sources
+    #global all_sources
 
     websearch = GoogleSearchAPIWrapper()
     search_results:dict = websearch.results(query, 3)
     cleaner_sources =format_search_results(search_results)
     parsed_csources = parse_list_to_dicts(cleaner_sources)
-    add_many(parsed_csources)
-    all_sources += cleaner_sources
+    #add_many(parsed_csources)
+    #all_sources += cleaner_sources
 
     return cleaner_sources.__str__()
innovation_pathfinder_ai/backend/app/utils/utils.py CHANGED

@@ -3,7 +3,7 @@ import datetime
 import os
 import uuid
 
-from innovation_pathfinder_ai.backend.app.utils import logger
+from app.utils import logger
 
 logger = logger.get_console_logger("utils")
 
innovation_pathfinder_ai/backend/app/vector_store/chroma_vector_store.py CHANGED

@@ -20,7 +20,7 @@ from langchain_community.vectorstores import Chroma
 from langchain_community.embeddings.sentence_transformer import (
     SentenceTransformerEmbeddings,
 )
-from innovation_pathfinder_ai.utils.utils import (
+from app.utils.utils import (
     generate_uuid
 )
 import dotenv
innovation_pathfinder_ai/frontend/app.py CHANGED

@@ -1,47 +1,34 @@
 from fastapi import FastAPI
 import gradio as gr
 from gradio.themes.base import Base
-from innovation_pathfinder_ai.backend.app.api.v1.agents.hf_mixtral_agent import agent_executor
-from innovation_pathfinder_ai.source_container.container import (
-    all_sources
-)
-from innovation_pathfinder_ai.backend.app.utils.utils import extract_urls
-from innovation_pathfinder_ai.backend.app.utils import logger
-
-from innovation_pathfinder_ai.backend.app.utils.utils import (
-    generate_uuid
-)
+#from innovation_pathfinder_ai.backend.app.api.v1.agents.hf_mixtral_agent import agent_executor
+#from innovation_pathfinder_ai.source_container.container import (
+#    all_sources
+#)
+#from innovation_pathfinder_ai.backend.app.utils.utils import extract_urls
+#from innovation_pathfinder_ai.backend.app.utils import logger
+#from innovation_pathfinder_ai.backend.app.vector_store.chroma_vector_store import initialize_chroma_db
+#from innovation_pathfinder_ai.backend.app.utils.utils import (
+#    generate_uuid
+#)
 from langchain_community.vectorstores import Chroma
 
-import chromadb
+import asyncio
+import websockets
+import json
 import dotenv
 import os
 
 dotenv.load_dotenv()
 persist_directory = os.getenv('VECTOR_DATABASE_LOCATION')
 
-logger = logger.get_console_logger("app")
+#logger = logger.get_console_logger("app")
 
 app = FastAPI()
 
-def initialize_chroma_db() -> Chroma:
-    collection_name = os.getenv('CONVERSATION_COLLECTION_NAME')
-
-    client = chromadb.PersistentClient(
-        path=persist_directory
-    )
-
-    collection = client.get_or_create_collection(
-        name=collection_name,
-    )
-
-    return collection
-
-
-
 if __name__ == "__main__":
 
-    db = initialize_chroma_db()
+    #db = initialize_chroma_db()
 
     def add_text(history, text):
         history = history + [(text, None)]
@@ -53,35 +40,51 @@ if __name__ == "__main__":
         # Example for calling generate_uuid from the backend
         # response = requests.post("http://localhost:8000/add-document")
         #current_id = response.text
-        sources = extract_urls(all_sources)
-        src_list = '\n'.join(sources)
-        current_id = generate_uuid()
-        db.add(
-            ids=[current_id],
-            documents=[response['output']],
-            metadatas=[
-                {
-                    "human_message":history[-1][0],
-                    "sources": 'Internal Knowledge Base From: \n\n' + src_list
-                }
-            ]
-        )
-        if not sources:
-            response_w_sources = response['output']+"\n\n\n Sources: \n\n\n Internal knowledge base"
-        else:
-            response_w_sources = response['output']+"\n\n\n Sources: \n\n\n"+src_list
-        history[-1][1] = response_w_sources
-        all_sources.clear()
+        # sources = extract_urls(all_sources)
+        # src_list = '\n'.join(sources)
+        # current_id = generate_uuid()
+        # db.add(
+        #     ids=[current_id],
+        #     documents=[response['output']],
+        #     metadatas=[
+        #         {
+        #             "human_message":history[-1][0],
+        #             "sources": 'Internal Knowledge Base From: \n\n' + src_list
+        #         }
+        #     ]
+        # )
+        # if not sources:
+        #     response_w_sources = response['output']+"\n\n\n Sources: \n\n\n Internal knowledge base"
+        # else:
+        #     response_w_sources = response['output']+"\n\n\n Sources: \n\n\n"+src_list
+        history[-1][1] = response['output']
+        # all_sources.clear()
         return history
 
     def infer(question, history):
-        query = question
-        result = agent_executor.invoke(
-            {
-                "input": question,
-                "chat_history": history
-            }
-        )
+        # result = agent_executor.invoke(
+        #     {
+        #         "input": question,
+        #         "chat_history": history
+        #     }
+        # )
+        # return result
+        async def ask_question_async(question, history):
+            uri = "ws://localhost:8000/chat/agent"  # Update this URI to your actual WebSocket endpoint
+            async with websockets.connect(uri) as websocket:
+                # Prepare the message to send (adjust the structure as needed for your backend)
+                message_data = {
+                    "message": question,
+                    "history": history
+                }
+                await websocket.send(json.dumps(message_data))
+
+                # Wait for the response
+                response_data = await websocket.recv()
+                return json.loads(response_data)
+
+        # Run the asynchronous function in the synchronous context
+        result = asyncio.get_event_loop().run_until_complete(ask_question_async(question, history))
         return result
 
     def vote(data: gr.LikeData):
@@ -122,7 +125,7 @@ if __name__ == "__main__":
         gr.Markdown("Nothing yet...")
 
     demo.queue()
-    demo.launch(debug=True, favicon_path="innovation_pathfinder_ai/assets/favicon.ico", share=True)
+    demo.launch(debug=True, favicon_path="assets/favicon.ico", share=True)
 
     x = 0 # for debugging purposes
     app = gr.mount_gradio_app(app, demo, path="/")
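One caveat on the new infer(): Gradio typically runs synchronous callbacks in worker threads, and asyncio.get_event_loop() raises a RuntimeError in a thread that has no event loop set (and is deprecated for this use in recent Python versions). A more defensive sketch of the same sync-to-async bridge, under the same assumed URI and payload shape as the code above:

import asyncio
import json

import websockets

async def ask_question_async(question, history):
    # Endpoint and message shape assumed by this commit's frontend code.
    uri = "ws://localhost:8000/chat/agent"
    async with websockets.connect(uri) as ws:
        await ws.send(json.dumps({"message": question, "history": history}))
        return json.loads(await ws.recv())

def infer(question, history):
    # asyncio.run() creates and closes a fresh event loop per call, the
    # documented replacement for get_event_loop().run_until_complete()
    # in synchronous callers.
    return asyncio.run(ask_question_async(question, history))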