isayahc commited on
Commit
3f39329
1 Parent(s): 1dd5b76

Cleaned up and refactored the tools and agent

Browse files
innovation_pathfinder_ai/source_container/container.py ADDED
@@ -0,0 +1 @@
 
 
1
+ all_sources = []
innovation_pathfinder_ai/structured_tools/structured_tools.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.tools import BaseTool, StructuredTool, tool
2
+ from langchain.retrievers import ArxivRetriever
3
+ from langchain_community.utilities import SerpAPIWrapper
4
+ import arxiv
5
+
6
+ # hacky and should be replaced with a database
7
+ from innovation_pathfinder_ai.source_container.container import (
8
+ all_sources
9
+ )
10
+
11
@tool
def arxiv_search(query: str) -> str:
    """Search arXiv for papers matching the query and return their metadata."""
    # NOTE: langchain's @tool decorator uses the docstring above as the tool
    # description shown to the agent, so it is kept to one plain sentence;
    # maintainer-facing notes live in comments instead.

    # Only the top two documents are loaded for each query.
    arxiv_retriever = ArxivRetriever(load_max_docs=2)
    documents = arxiv_retriever.invoke(query)
    meta_data = [document.metadata for document in documents]

    # Record the sources in the shared module-level container. extend()
    # mutates the imported list in place, so no `global` statement is needed
    # and other modules holding the same list object see the additions.
    all_sources.extend(meta_data)

    # The agent consumes text, so hand back the string form of the metadata.
    return str(meta_data)
28
+
29
@tool
def get_arxiv_paper(paper_id: str) -> None:
    """Download a paper from arXiv. To download a paper please input
    the arXiv id such as "1605.08386v1". This tool is named get_arxiv_paper.
    If you input "http://arxiv.org/abs/2312.02813", this will break the code. Only input
    the id, for example "2312.02813". In addition please download one paper at a time.
    Please keep the inputs/output free of additional information; only have the id.
    """
    # Local import so this block does not depend on a file-level `import os`.
    import os

    download_dir = "./mydir"

    # Agent-produced inputs frequently carry stray whitespace or newlines.
    paper_id = paper_id.strip()

    # code from https://lukasschwab.me/arxiv.py/arxiv.html
    search = arxiv.Search(id_list=[paper_id])
    # Use a default with next() so an empty result set produces a clear error
    # instead of leaking a bare StopIteration to the caller.
    paper = next(arxiv.Client().results(search), None)
    if paper is None:
        raise ValueError(f"No arXiv paper found for id {paper_id!r}")

    # The file name is the id with its dots removed, e.g. "231202813.pdf".
    number_without_period = paper_id.replace('.', '')

    # Make sure the target directory exists before downloading into it.
    os.makedirs(download_dir, exist_ok=True)

    # Download the PDF to a specified directory with a custom filename.
    paper.download_pdf(dirpath=download_dir, filename=f"{number_without_period}.pdf")
44
+
45
+
46
@tool
def google_search(query: str) -> str:
    """Search Google for the query and return the title, link and snippet of each result."""
    # NOTE: langchain's @tool decorator uses the docstring above as the tool
    # description shown to the agent, so it is kept to one plain sentence.
    search_results: dict = SerpAPIWrapper().results(query)

    # A response is not guaranteed to contain "organic_results" (e.g. when
    # nothing matched), and an individual result may lack a field such as
    # "snippet". Fall back to empty values instead of raising KeyError.
    organic_results = search_results.get("organic_results", [])
    cleaner_sources = [
        "Title: {title}, link: {link}, snippet: {snippet}".format(
            title=result.get("title", ""),
            link=result.get("link", ""),
            snippet=result.get("snippet", ""),
        )
        for result in organic_results
    ]

    # Record the sources in the shared module-level container. extend()
    # mutates the imported list in place, so no `global` statement is needed.
    all_sources.extend(cleaner_sources)

    # The agent consumes text, so hand back the string form of the list.
    return str(cleaner_sources)
mixtral_agent.py CHANGED
@@ -1,31 +1,29 @@
1
  # LangChain supports many other chat models. Here, we're using Ollama
2
  from langchain_community.chat_models import ChatOllama
3
- from langchain_core.output_parsers import StrOutputParser
4
  from langchain_core.prompts import ChatPromptTemplate
5
- from langchain.tools.retriever import create_retriever_tool
6
- from langchain_community.utilities import SerpAPIWrapper
7
- from langchain.retrievers import ArxivRetriever
8
- from langchain_core.tools import Tool
9
  from langchain import hub
10
- from langchain.agents import AgentExecutor, load_tools
11
  from langchain.agents.format_scratchpad import format_log_to_str
12
  from langchain.agents.output_parsers import (
13
  ReActJsonSingleInputOutputParser,
14
  )
15
  # Import things that are needed generically
16
- from langchain.pydantic_v1 import BaseModel, Field
17
- from langchain.tools import BaseTool, StructuredTool, tool
18
  from typing import List, Dict
19
- from datetime import datetime
20
  from langchain.tools.render import render_text_description
21
  import os
22
- import arxiv
23
-
24
  import dotenv
 
 
 
 
 
 
 
 
 
25
 
26
  dotenv.load_dotenv()
27
 
28
-
29
  OLLMA_BASE_URL = os.getenv("OLLMA_BASE_URL")
30
 
31
 
@@ -37,89 +35,6 @@ llm = ChatOllama(
37
  )
38
  prompt = ChatPromptTemplate.from_template("Tell me a short joke about {topic}")
39
 
40
- arxiv_retriever = ArxivRetriever(load_max_docs=2)
41
-
42
- from zipfile import ZipFile
43
-
44
- def unzip_file(zip_file: str, extract_to: str) -> None:
45
- with ZipFile(zip_file, 'r') as zip_ref:
46
- zip_ref.extractall(extract_to)
47
-
48
-
49
-
50
- def format_info_list(info_list: List[Dict[str, str]]) -> str:
51
- """
52
- Format a list of dictionaries containing information into a single string.
53
-
54
- Args:
55
- info_list (List[Dict[str, str]]): A list of dictionaries containing information.
56
-
57
- Returns:
58
- str: A formatted string containing the information from the list.
59
- """
60
- formatted_strings = []
61
- for info_dict in info_list:
62
- formatted_string = "|"
63
- for key, value in info_dict.items():
64
- if isinstance(value, datetime.date):
65
- value = value.strftime('%Y-%m-%d')
66
- formatted_string += f"'{key}': '{value}', "
67
- formatted_string = formatted_string.rstrip(', ') + "|"
68
- formatted_strings.append(formatted_string)
69
- return '\n'.join(formatted_strings)
70
-
71
- @tool
72
- def arxiv_search(query: str) -> str:
73
- """Using the arxiv search and collects metadata."""
74
- # return "LangChain"
75
- global all_sources
76
- data = arxiv_retriever.invoke(query)
77
- meta_data = [i.metadata for i in data]
78
- # meta_data += all_sources
79
- # all_sources += meta_data
80
- all_sources += meta_data
81
-
82
- # formatted_info = format_info(entry_id, published, title, authors)
83
-
84
- # formatted_info = format_info_list(all_sources)
85
-
86
- return meta_data.__str__()
87
-
88
- @tool
89
- def google_search(query: str) -> str:
90
- """Using the google search and collects metadata."""
91
- # return "LangChain"
92
- global all_sources
93
-
94
- x = SerpAPIWrapper()
95
- search_results:dict = x.results(query)
96
-
97
-
98
- organic_source = search_results['organic_results']
99
- # formatted_string = "Title: {title}, link: {link}, snippet: {snippet}".format(**organic_source)
100
- cleaner_sources = ["Title: {title}, link: {link}, snippet: {snippet}".format(**i) for i in organic_source]
101
-
102
- all_sources += cleaner_sources
103
-
104
- return cleaner_sources.__str__()
105
- # return organic_source
106
-
107
- @tool
108
- def get_arxiv_paper(paper_id:str) -> None:
109
- """Download a paper from axriv to download a paper please input
110
- the axriv id such as "1605.08386v1" This tool is named get_arxiv_paper
111
- If you input "http://arxiv.org/abs/2312.02813", This will break the code. Also only do
112
- "2312.02813". In addition please download one paper at a time. Pleaase keep the inputs/output
113
- free of additional information only have the id.
114
- """
115
- # code from https://lukasschwab.me/arxiv.py/arxiv.html
116
- paper = next(arxiv.Client().results(arxiv.Search(id_list=[paper_id])))
117
-
118
- number_without_period = paper_id.replace('.', '')
119
-
120
- # Download the PDF to a specified directory with a custom filename.
121
- paper.download_pdf(dirpath="./mydir", filename=f"{number_without_period}.pdf")
122
-
123
 
124
  tools = [
125
  arxiv_search,
@@ -127,22 +42,6 @@ tools = [
127
  get_arxiv_paper,
128
  ]
129
 
130
- # tools = [
131
- # create_retriever_tool(
132
- # retriever,
133
- # "search arxiv's database for",
134
- # "Use this to recomend the user a paper to read Unless stated please choose the most recent models",
135
- # # "Searches and returns excerpts from the 2022 State of the Union.",
136
- # ),
137
-
138
- # Tool(
139
- # name="SerpAPI",
140
- # description="A low-cost Google Search API. Useful for when you need to answer questions about current events. Input should be a search query.",
141
- # func=SerpAPIWrapper().run,
142
- # )
143
-
144
- # ]
145
-
146
 
147
  prompt = hub.pull("hwchase17/react-json")
148
  prompt = prompt.partial(
@@ -150,9 +49,9 @@ prompt = prompt.partial(
150
  tool_names=", ".join([t.name for t in tools]),
151
  )
152
 
153
- chat_model = llm
154
  # define the agent
155
- chat_model_with_stop = chat_model.bind(stop=["\nObservation"])
156
  agent = (
157
  {
158
  "input": lambda x: x["input"],
@@ -176,7 +75,6 @@ agent_executor = AgentExecutor(
176
  if __name__ == "__main__":
177
 
178
  # global variable for collecting sources
179
- all_sources = []
180
 
181
  input = agent_executor.invoke(
182
  {
 
1
  # LangChain supports many other chat models. Here, we're using Ollama
2
  from langchain_community.chat_models import ChatOllama
 
3
  from langchain_core.prompts import ChatPromptTemplate
 
 
 
 
4
  from langchain import hub
5
+ from langchain.agents import AgentExecutor
6
  from langchain.agents.format_scratchpad import format_log_to_str
7
  from langchain.agents.output_parsers import (
8
  ReActJsonSingleInputOutputParser,
9
  )
10
  # Import things that are needed generically
 
 
11
  from typing import List, Dict
 
12
  from langchain.tools.render import render_text_description
13
  import os
 
 
14
  import dotenv
15
+ from innovation_pathfinder_ai.structured_tools.structured_tools import (
16
+ arxiv_search, get_arxiv_paper, google_search
17
+ )
18
+
19
+ # hacky and should be replaced with a database
20
+ from innovation_pathfinder_ai.source_container.container import (
21
+ all_sources
22
+ )
23
+
24
 
25
  dotenv.load_dotenv()
26
 
 
27
  OLLMA_BASE_URL = os.getenv("OLLMA_BASE_URL")
28
 
29
 
 
35
  )
36
  prompt = ChatPromptTemplate.from_template("Tell me a short joke about {topic}")
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  tools = [
40
  arxiv_search,
 
42
  get_arxiv_paper,
43
  ]
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  prompt = hub.pull("hwchase17/react-json")
47
  prompt = prompt.partial(
 
49
  tool_names=", ".join([t.name for t in tools]),
50
  )
51
 
52
+
53
  # define the agent
54
+ chat_model_with_stop = llm.bind(stop=["\nObservation"])
55
  agent = (
56
  {
57
  "input": lambda x: x["input"],
 
75
  if __name__ == "__main__":
76
 
77
  # global variable for collecting sources
 
78
 
79
  input = agent_executor.invoke(
80
  {