Asaad Almutareb commited on
Commit
631f6af
1 Parent(s): a7901c4

moved ollama agent to it's own folder

Browse files

added mixtral agent
switched from serpapi to googleSearchAPI
added sources urls to final output

.gitignore CHANGED
@@ -164,6 +164,7 @@ cython_debug/
164
  *.bin
165
  *.pickle
166
  chroma_db/*
 
167
  bin
168
  obj
169
  .langchain.sqlite
 
164
  *.bin
165
  *.pickle
166
  chroma_db/*
167
+ downloaded_papers/*
168
  bin
169
  obj
170
  .langchain.sqlite
hf_mixtral_agent.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HF libraries
2
+ from langchain_community.llms import HuggingFaceHub
3
+ from langchain_core.prompts import ChatPromptTemplate
4
+ from langchain import hub
5
+ import gradio as gr
6
+ from langchain.agents import AgentExecutor
7
+ from langchain.agents.format_scratchpad import format_log_to_str
8
+ from langchain.agents.output_parsers import (
9
+ ReActJsonSingleInputOutputParser,
10
+ )
11
+ # Import things that are needed generically
12
+ from typing import List, Dict
13
+ from langchain.tools.render import render_text_description
14
+ import os
15
+ from dotenv import load_dotenv
16
+ from innovation_pathfinder_ai.structured_tools.structured_tools import (
17
+ arxiv_search, get_arxiv_paper, google_search, wikipedia_search
18
+ )
19
+
20
+ # hacky and should be replaced with a database
21
+ from innovation_pathfinder_ai.source_container.container import (
22
+ all_sources
23
+ )
24
+
25
+ # from langchain_community.chat_message_histories import ChatMessageHistory
26
+ # from langchain_core.runnables.history import RunnableWithMessageHistory
27
+
28
+ # message_history = ChatMessageHistory()
29
+ config = load_dotenv(".env")
30
+ HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
31
+ GOOGLE_CSE_ID = os.getenv('GOOGLE_CSE_ID')
32
+ GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
33
+ LANGCHAIN_TRACING_V2 = "true"
34
+ LANGCHAIN_ENDPOINT = "https://api.smith.langchain.com"
35
+ LANGCHAIN_API_KEY = os.getenv('LANGCHAIN_API_KEY')
36
+ LANGCHAIN_PROJECT = os.getenv('LANGCHAIN_PROJECT')
37
+
38
+ # Load the model from the Hugging Face Hub
39
+ llm = HuggingFaceHub(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1", model_kwargs={
40
+ "temperature":0.1,
41
+ "max_new_tokens":1024,
42
+ "repetition_penalty":1.2,
43
+ "return_full_text":False
44
+ })
45
+
46
+
47
+ tools = [
48
+ arxiv_search,
49
+ wikipedia_search,
50
+ google_search,
51
+ # get_arxiv_paper,
52
+ ]
53
+
54
+
55
+ prompt = hub.pull("hwchase17/react-json")
56
+ prompt = prompt.partial(
57
+ tools=render_text_description(tools),
58
+ tool_names=", ".join([t.name for t in tools]),
59
+ )
60
+
61
+
62
+ # define the agent
63
+ chat_model_with_stop = llm.bind(stop=["\nObservation"])
64
+ agent = (
65
+ {
66
+ "input": lambda x: x["input"],
67
+ "agent_scratchpad": lambda x: format_log_to_str(x["intermediate_steps"]),
68
+ }
69
+ | prompt
70
+ | chat_model_with_stop
71
+ | ReActJsonSingleInputOutputParser()
72
+ )
73
+
74
+ # instantiate AgentExecutor
75
+ agent_executor = AgentExecutor(
76
+ agent=agent,
77
+ tools=tools,
78
+ verbose=True,
79
+ max_iterations=10, # cap number of iterations
80
+ #max_execution_time=60, # timout at 60 sec
81
+ return_intermediate_steps=True,
82
+ handle_parsing_errors=True,
83
+ )
84
+
85
+
86
+ if __name__ == "__main__":
87
+
88
+ def add_text(history, text):
89
+ history = history + [(text, None)]
90
+ return history, ""
91
+
92
+ def bot(history):
93
+ response = infer(history[-1][0], history)
94
+ sources = collect_urls(all_sources)
95
+ src_list = '\n'.join(sources)
96
+ response_w_sources = response['output']+"\n\n\n Sources: \n\n\n"+src_list
97
+ history[-1][1] = response_w_sources
98
+ return history
99
+
100
+ def infer(question, history):
101
+ query = question
102
+ result = agent_executor.invoke(
103
+ {
104
+ "input": question,
105
+ }
106
+ )
107
+ return result
108
+
109
+ def vote(data: gr.LikeData):
110
+ if data.liked:
111
+ print("You upvoted this response: " + data.value)
112
+ else:
113
+ print("You downvoted this response: " + data.value)
114
+
115
+ def collect_urls(data_list):
116
+ urls = []
117
+ for item in data_list:
118
+ # Check if item is a string and contains 'link:'
119
+ if isinstance(item, str) and 'link:' in item:
120
+ start = item.find('link:') + len('link: ')
121
+ end = item.find(',', start)
122
+ url = item[start:end if end != -1 else None].strip()
123
+ urls.append(url)
124
+ # Check if item is a dictionary and has 'Entry ID'
125
+ elif isinstance(item, dict) and 'Entry ID' in item:
126
+ urls.append(item['Entry ID'])
127
+ return urls
128
+
129
+ css="""
130
+ #col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
131
+ """
132
+
133
+ title = """
134
+ <div style="text-align: center;max-width: 700px;">
135
+ <p>Hello Dave, how can I help today?<br />
136
+ </div>
137
+ """
138
+
139
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
140
+ with gr.Tab("Google|Wikipedia|Arxiv"):
141
+ with gr.Column(elem_id="col-container"):
142
+ gr.HTML(title)
143
+ with gr.Row():
144
+ question = gr.Textbox(label="Question", placeholder="Type your question and hit Enter ")
145
+ chatbot = gr.Chatbot([], elem_id="chatbot")
146
+ chatbot.like(vote, None, None)
147
+ clear = gr.Button("Clear")
148
+ question.submit(add_text, [chatbot, question], [chatbot, question], queue=False).then(
149
+ bot, chatbot, chatbot
150
+ )
151
+ clear.click(lambda: None, None, chatbot, queue=False)
152
+
153
+ demo.queue()
154
+ demo.launch(debug=True)
155
+
156
+
157
+ x = 0 # for debugging purposes
mixtral_agent.py → innovation_pathfinder_ai/agent_ollama/ollama_mixtral_agent.py RENAMED
File without changes
innovation_pathfinder_ai/agent_ollama/requirements.txt ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ langchain
2
+ langchain-community
3
+ beautifulsoup4
4
+ faiss-cpu
5
+ chromadb
6
+ validators
7
+ sentence_transformers
8
+ typing-extensions==4.8.0
9
+ unstructured
10
+ gradio==3.48.0
11
+ boto3
12
+ aiofiles==23.2.1
13
+ aiohttp==3.9.3
14
+ aiosignal==1.3.1
15
+ altair==5.2.0
16
+ annotated-types==0.6.0
17
+ anyio==4.2.0
18
+ arxiv==2.1.0
19
+ asgiref==3.7.2
20
+ async-timeout==4.0.3
21
+ attrs==23.2.0
22
+ backoff==2.2.1
23
+ bcrypt==4.1.2
24
+ beautifulsoup4==4.12.3
25
+ boto3==1.34.44
26
+ botocore==1.34.44
27
+ build==1.0.3
28
+ cachetools==5.3.2
29
+ certifi==2024.2.2
30
+ chardet==5.2.0
31
+ charset-normalizer==3.3.2
32
+ chroma-hnswlib==0.7.3
33
+ chromadb==0.4.22
34
+ click==8.1.7
35
+ coloredlogs==15.0.1
36
+ contourpy==1.2.0
37
+ cycler==0.12.1
38
+ dataclasses-json==0.6.4
39
+ dataclasses-json-speakeasy==0.5.11
40
+ Deprecated==1.2.14
41
+ emoji==2.10.1
42
+ exceptiongroup==1.2.0
43
+ faiss-cpu==1.7.4
44
+ fastapi==0.109.2
45
+ feedparser==6.0.10
46
+ ffmpy==0.3.2
47
+ filelock==3.13.1
48
+ filetype==1.2.0
49
+ flatbuffers==23.5.26
50
+ fonttools==4.49.0
51
+ frozenlist==1.4.1
52
+ fsspec==2024.2.0
53
+ google-auth==2.28.0
54
+ google-search-results==2.4.2
55
+ googleapis-common-protos==1.62.0
56
+ gradio==3.48.0
57
+ gradio_client==0.6.1
58
+ greenlet==3.0.3
59
+ grpcio==1.60.1
60
+ h11==0.14.0
61
+ httpcore==1.0.3
62
+ httptools==0.6.1
63
+ httpx==0.26.0
64
+ huggingface-hub==0.20.3
65
+ humanfriendly==10.0
66
+ idna==3.6
67
+ importlib-metadata==6.11.0
68
+ importlib-resources==6.1.1
69
+ Jinja2==3.1.3
70
+ jmespath==1.0.1
71
+ joblib==1.3.2
72
+ jsonpatch==1.33
73
+ jsonpath-python==1.0.6
74
+ jsonpointer==2.4
75
+ jsonschema==4.21.1
76
+ jsonschema-specifications==2023.12.1
77
+ kiwisolver==1.4.5
78
+ kubernetes==29.0.0
79
+ langchain==0.1.7
80
+ langchain-community==0.0.20
81
+ langchain-core==0.1.23
82
+ langchainhub==0.1.14
83
+ langdetect==1.0.9
84
+ langsmith==0.0.87
85
+ lxml==5.1.0
86
+ MarkupSafe==2.1.5
87
+ marshmallow==3.20.2
88
+ matplotlib==3.8.3
89
+ mmh3==4.1.0
90
+ monotonic==1.6
91
+ mpmath==1.3.0
92
+ multidict==6.0.5
93
+ mypy-extensions==1.0.0
94
+ networkx==3.2.1
95
+ nltk==3.8.1
96
+ numpy==1.26.4
97
+ nvidia-cublas-cu12==12.1.3.1
98
+ nvidia-cuda-cupti-cu12==12.1.105
99
+ nvidia-cuda-nvrtc-cu12==12.1.105
100
+ nvidia-cuda-runtime-cu12==12.1.105
101
+ nvidia-cudnn-cu12==8.9.2.26
102
+ nvidia-cufft-cu12==11.0.2.54
103
+ nvidia-curand-cu12==10.3.2.106
104
+ nvidia-cusolver-cu12==11.4.5.107
105
+ nvidia-cusparse-cu12==12.1.0.106
106
+ nvidia-nccl-cu12==2.19.3
107
+ nvidia-nvjitlink-cu12==12.3.101
108
+ nvidia-nvtx-cu12==12.1.105
109
+ oauthlib==3.2.2
110
+ onnxruntime==1.17.0
111
+ opentelemetry-api==1.22.0
112
+ opentelemetry-exporter-otlp-proto-common==1.22.0
113
+ opentelemetry-exporter-otlp-proto-grpc==1.22.0
114
+ opentelemetry-instrumentation==0.43b0
115
+ opentelemetry-instrumentation-asgi==0.43b0
116
+ opentelemetry-instrumentation-fastapi==0.43b0
117
+ opentelemetry-proto==1.22.0
118
+ opentelemetry-sdk==1.22.0
119
+ opentelemetry-semantic-conventions==0.43b0
120
+ opentelemetry-util-http==0.43b0
121
+ orjson==3.9.14
122
+ overrides==7.7.0
123
+ packaging==23.2
124
+ pandas==2.2.0
125
+ pillow==10.2.0
126
+ posthog==3.4.1
127
+ protobuf==4.25.3
128
+ pulsar-client==3.4.0
129
+ pyasn1==0.5.1
130
+ pyasn1-modules==0.3.0
131
+ pydantic==2.6.1
132
+ pydantic_core==2.16.2
133
+ pydub==0.25.1
134
+ pyparsing==3.1.1
135
+ PyPika==0.48.9
136
+ pyproject_hooks==1.0.0
137
+ python-dateutil==2.8.2
138
+ python-dotenv==1.0.1
139
+ python-iso639==2024.2.7
140
+ python-magic==0.4.27
141
+ python-multipart==0.0.9
142
+ pytz==2024.1
143
+ PyYAML==6.0.1
144
+ rapidfuzz==3.6.1
145
+ referencing==0.33.0
146
+ regex==2023.12.25
147
+ requests==2.31.0
148
+ requests-oauthlib==1.3.1
149
+ rpds-py==0.18.0
150
+ rsa==4.9
151
+ s3transfer==0.10.0
152
+ safetensors==0.4.2
153
+ scikit-learn==1.4.1.post1
154
+ scipy==1.12.0
155
+ semantic-version==2.10.0
156
+ sentence-transformers==2.3.1
157
+ sentencepiece==0.1.99
158
+ sgmllib3k==1.0.0
159
+ six==1.16.0
160
+ sniffio==1.3.0
161
+ soupsieve==2.5
162
+ SQLAlchemy==2.0.27
163
+ starlette==0.36.3
164
+ sympy==1.12
165
+ tabulate==0.9.0
166
+ tenacity==8.2.3
167
+ threadpoolctl==3.3.0
168
+ tokenizers==0.15.2
169
+ tomli==2.0.1
170
+ toolz==0.12.1
171
+ torch==2.2.0
172
+ tqdm==4.66.2
173
+ transformers==4.37.2
174
+ triton==2.2.0
175
+ typer==0.9.0
176
+ types-requests==2.31.0.20240125
177
+ typing-inspect==0.9.0
178
+ typing_extensions==4.8.0
179
+ tzdata==2024.1
180
+ unstructured==0.12.4
181
+ unstructured-client==0.18.0
182
+ urllib3==2.0.7
183
+ uvicorn==0.27.1
184
+ uvloop==0.19.0
185
+ validators==0.22.0
186
+ watchfiles==0.21.0
187
+ websocket-client==1.7.0
188
+ websockets==11.0.3
189
+ wrapt==1.16.0
190
+ yarl==1.9.4
191
+ zipp==3.17.0
192
+ aiofiles==23.2.1
193
+ aiohttp==3.9.3
194
+ aiosignal==1.3.1
195
+ altair==5.2.0
196
+ annotated-types==0.6.0
197
+ anyio==4.2.0
198
+ arxiv==2.1.0
199
+ asgiref==3.7.2
200
+ async-timeout==4.0.3
201
+ attrs==23.2.0
202
+ backoff==2.2.1
203
+ bcrypt==4.1.2
204
+ beautifulsoup4==4.12.3
205
+ boto3==1.34.42
206
+ botocore==1.34.42
207
+ build==1.0.3
208
+ cachetools==5.3.2
209
+ certifi==2024.2.2
210
+ chardet==5.2.0
211
+ charset-normalizer==3.3.2
212
+ chroma-hnswlib==0.7.3
213
+ chromadb==0.4.22
214
+ click==8.1.7
215
+ coloredlogs==15.0.1
216
+ contourpy==1.2.0
217
+ cycler==0.12.1
218
+ dataclasses-json==0.6.4
219
+ dataclasses-json-speakeasy==0.5.11
220
+ Deprecated==1.2.14
221
+ emoji==2.10.1
222
+ exceptiongroup==1.2.0
223
+ faiss-cpu==1.7.4
224
+ fastapi==0.109.2
225
+ feedparser==6.0.10
226
+ ffmpy==0.3.2
227
+ filelock==3.13.1
228
+ filetype==1.2.0
229
+ flatbuffers==23.5.26
230
+ fonttools==4.48.1
231
+ frozenlist==1.4.1
232
+ fsspec==2024.2.0
233
+ gitdb==4.0.11
234
+ GitPython==3.1.41
235
+ google-auth==2.27.0
236
+ google_search_results==2.4.2
237
+ googleapis-common-protos==1.62.0
238
+ gradio==3.48.0
239
+ gradio_client==0.6.1
240
+ greenlet==3.0.3
241
+ grpcio==1.60.1
242
+ h11==0.14.0
243
+ httpcore==1.0.3
244
+ httptools==0.6.1
245
+ httpx==0.26.0
246
+ huggingface-hub==0.20.3
247
+ humanfriendly==10.0
248
+ idna==3.6
249
+ importlib-metadata==6.11.0
250
+ importlib-resources==6.1.1
251
+ Jinja2==3.1.3
252
+ jmespath==1.0.1
253
+ joblib==1.3.2
254
+ jsonpatch==1.33
255
+ jsonpath-python==1.0.6
256
+ jsonpointer==2.4
257
+ jsonschema==4.21.1
258
+ jsonschema-specifications==2023.12.1
259
+ kiwisolver==1.4.5
260
+ kubernetes==29.0.0
261
+ langchain==0.1.7
262
+ langchain-community==0.0.20
263
+ langchain-core==0.1.23
264
+ langchainhub==0.1.14
265
+ langdetect==1.0.9
266
+ langsmith==0.0.87
267
+ lxml==5.1.0
268
+ MarkupSafe==2.1.5
269
+ marshmallow==3.20.2
270
+ matplotlib==3.8.3
271
+ mmh3==4.1.0
272
+ monotonic==1.6
273
+ mpmath==1.3.0
274
+ multidict==6.0.5
275
+ mypy-extensions==1.0.0
276
+ networkx==3.2.1
277
+ nltk==3.8.1
278
+ numpy==1.26.4
279
+ nvidia-cublas-cu12==12.1.3.1
280
+ nvidia-cuda-cupti-cu12==12.1.105
281
+ nvidia-cuda-nvrtc-cu12==12.1.105
282
+ nvidia-cuda-runtime-cu12==12.1.105
283
+ nvidia-cudnn-cu12==8.9.2.26
284
+ nvidia-cufft-cu12==11.0.2.54
285
+ nvidia-curand-cu12==10.3.2.106
286
+ nvidia-cusolver-cu12==11.4.5.107
287
+ nvidia-cusparse-cu12==12.1.0.106
288
+ nvidia-nccl-cu12==2.19.3
289
+ nvidia-nvjitlink-cu12==12.3.101
290
+ nvidia-nvtx-cu12==12.1.105
291
+ oauthlib==3.2.2
292
+ onnxruntime==1.17.0
293
+ opentelemetry-api==1.22.0
294
+ opentelemetry-exporter-otlp-proto-common==1.22.0
295
+ opentelemetry-exporter-otlp-proto-grpc==1.22.0
296
+ opentelemetry-instrumentation==0.43b0
297
+ opentelemetry-instrumentation-asgi==0.43b0
298
+ opentelemetry-instrumentation-fastapi==0.43b0
299
+ opentelemetry-proto==1.22.0
300
+ opentelemetry-sdk==1.22.0
301
+ opentelemetry-semantic-conventions==0.43b0
302
+ opentelemetry-util-http==0.43b0
303
+ orjson==3.9.14
304
+ overrides==7.7.0
305
+ packaging==23.2
306
+ pandas==2.2.0
307
+ pillow==10.2.0
308
+ posthog==3.4.1
309
+ protobuf==4.25.2
310
+ pulsar-client==3.4.0
311
+ pyasn1==0.5.1
312
+ pyasn1-modules==0.3.0
313
+ pydantic==2.6.1
314
+ pydantic_core==2.16.2
315
+ pydub==0.25.1
316
+ pyparsing==3.1.1
317
+ PyPika==0.48.9
318
+ pyproject_hooks==1.0.0
319
+ python-dateutil==2.8.2
320
+ python-dotenv==1.0.1
321
+ python-iso639==2024.2.7
322
+ python-magic==0.4.27
323
+ python-multipart==0.0.9
324
+ pytz==2024.1
325
+ PyYAML==6.0.1
326
+ rapidfuzz==3.6.1
327
+ referencing==0.33.0
328
+ regex==2023.12.25
329
+ requests==2.31.0
330
+ requests-oauthlib==1.3.1
331
+ rpds-py==0.18.0
332
+ rsa==4.9
333
+ s3transfer==0.10.0
334
+ safetensors==0.4.2
335
+ scikit-learn==1.4.0
336
+ scipy==1.12.0
337
+ semantic-version==2.10.0
338
+ sentence-transformers==2.3.1
339
+ sentencepiece==0.1.99
340
+ sgmllib3k==1.0.0
341
+ six==1.16.0
342
+ smmap==5.0.1
343
+ sniffio==1.3.0
344
+ soupsieve==2.5
345
+ SQLAlchemy==2.0.27
346
+ starlette==0.36.3
347
+ sympy==1.12
348
+ tabulate==0.9.0
349
+ tenacity==8.2.3
350
+ threadpoolctl==3.3.0
351
+ tokenizers==0.15.2
352
+ tomli==2.0.1
353
+ toolz==0.12.1
354
+ torch==2.2.0
355
+ tqdm==4.66.2
356
+ transformers==4.37.2
357
+ triton==2.2.0
358
+ typer==0.9.0
359
+ types-requests==2.31.0.20240125
360
+ typing-inspect==0.9.0
361
+ typing_extensions==4.8.0
362
+ tzdata==2024.1
363
+ unstructured==0.12.4
364
+ unstructured-client==0.18.0
365
+ urllib3==2.0.7
366
+ uvicorn==0.27.1
367
+ uvloop==0.19.0
368
+ validators==0.22.0
369
+ watchfiles==0.21.0
370
+ websocket-client==1.7.0
371
+ websockets==11.0.3
372
+ wrapt==1.16.0
373
+ yarl==1.9.4
374
+ zipp==3.17.0
innovation_pathfinder_ai/structured_tools/structured_tools.py CHANGED
@@ -1,6 +1,10 @@
1
  from langchain.tools import BaseTool, StructuredTool, tool
2
- from langchain.retrievers import ArxivRetriever
3
- from langchain_community.utilities import SerpAPIWrapper
 
 
 
 
4
  import arxiv
5
 
6
  # hacky and should be replaced with a database
@@ -10,7 +14,8 @@ from innovation_pathfinder_ai.source_container.container import (
10
 
11
  @tool
12
  def arxiv_search(query: str) -> str:
13
- """Using the arxiv search and collects metadata."""
 
14
  # return "LangChain"
15
  global all_sources
16
  arxiv_retriever = ArxivRetriever(load_max_docs=2)
@@ -40,23 +45,34 @@ def get_arxiv_paper(paper_id:str) -> None:
40
  number_without_period = paper_id.replace('.', '')
41
 
42
  # Download the PDF to a specified directory with a custom filename.
43
- paper.download_pdf(dirpath="./mydir", filename=f"{number_without_period}.pdf")
44
 
45
 
46
  @tool
47
  def google_search(query: str) -> str:
48
- """Using the google search and collects metadata."""
49
  # return "LangChain"
50
  global all_sources
51
 
52
- x = SerpAPIWrapper()
53
- search_results:dict = x.results(query)
54
 
55
 
56
- organic_source = search_results['organic_results']
57
  # formatted_string = "Title: {title}, link: {link}, snippet: {snippet}".format(**organic_source)
58
- cleaner_sources = ["Title: {title}, link: {link}, snippet: {snippet}".format(**i) for i in organic_source]
59
 
60
  all_sources += cleaner_sources
61
 
62
- return cleaner_sources.__str__()
 
 
 
 
 
 
 
 
 
 
 
 
1
  from langchain.tools import BaseTool, StructuredTool, tool
2
+ from langchain_community.retrievers import ArxivRetriever
3
+ #from langchain_community.utilities import SerpAPIWrapper
4
+ from langchain_community.tools import WikipediaQueryRun
5
+ from langchain_community.utilities import WikipediaAPIWrapper
6
+ #from langchain.tools import Tool
7
+ from langchain_community.utilities import GoogleSearchAPIWrapper
8
  import arxiv
9
 
10
  # hacky and should be replaced with a database
 
14
 
15
  @tool
16
  def arxiv_search(query: str) -> str:
17
+ """Search arxiv database for scientific research papers and studies. This is your primary information source.
18
+ always check it first when you search for information, before using any other tool."""
19
  # return "LangChain"
20
  global all_sources
21
  arxiv_retriever = ArxivRetriever(load_max_docs=2)
 
45
  number_without_period = paper_id.replace('.', '')
46
 
47
  # Download the PDF to a specified directory with a custom filename.
48
+ paper.download_pdf(dirpath="./downloaded_papers", filename=f"{number_without_period}.pdf")
49
 
50
 
51
  @tool
52
  def google_search(query: str) -> str:
53
+ """Search Google for additional results when you can't answer questions using arxiv search or wikipedia search."""
54
  # return "LangChain"
55
  global all_sources
56
 
57
+ websearch = GoogleSearchAPIWrapper()
58
+ search_results:dict = websearch.results(query, 5)
59
 
60
 
61
+ #organic_source = search_results['organic_results']
62
  # formatted_string = "Title: {title}, link: {link}, snippet: {snippet}".format(**organic_source)
63
+ cleaner_sources = ["Title: {title}, link: {link}, snippet: {snippet}".format(**i) for i in search_results]
64
 
65
  all_sources += cleaner_sources
66
 
67
+ return cleaner_sources.__str__()
68
+
69
+ @tool
70
+ def wikipedia_search(query: str) -> str:
71
+ """Search Wikipedia for additional information to expand on research papers or when no papers can be found."""
72
+ global all_sources
73
+
74
+ api_wrapper = WikipediaAPIWrapper()
75
+ wikipedia_search = WikipediaQueryRun(api_wrapper=api_wrapper)
76
+ wikipedia_results = wikipedia_search.run(query)
77
+ all_sources += wikipedia_results
78
+ return wikipedia_results
requirements.txt CHANGED
@@ -1,374 +1,10 @@
1
- langchain
2
- langchain-community
3
- beautifulsoup4
4
- faiss-cpu
5
- chromadb
6
- validators
7
- sentence_transformers
8
- typing-extensions==4.8.0
9
- unstructured
10
  gradio==3.48.0
11
- boto3
12
- aiofiles==23.2.1
13
- aiohttp==3.9.3
14
- aiosignal==1.3.1
15
- altair==5.2.0
16
- annotated-types==0.6.0
17
- anyio==4.2.0
18
- arxiv==2.1.0
19
- asgiref==3.7.2
20
- async-timeout==4.0.3
21
- attrs==23.2.0
22
- backoff==2.2.1
23
- bcrypt==4.1.2
24
- beautifulsoup4==4.12.3
25
- boto3==1.34.44
26
- botocore==1.34.44
27
- build==1.0.3
28
- cachetools==5.3.2
29
- certifi==2024.2.2
30
- chardet==5.2.0
31
- charset-normalizer==3.3.2
32
- chroma-hnswlib==0.7.3
33
- chromadb==0.4.22
34
- click==8.1.7
35
- coloredlogs==15.0.1
36
- contourpy==1.2.0
37
- cycler==0.12.1
38
- dataclasses-json==0.6.4
39
- dataclasses-json-speakeasy==0.5.11
40
- Deprecated==1.2.14
41
- emoji==2.10.1
42
- exceptiongroup==1.2.0
43
- faiss-cpu==1.7.4
44
- fastapi==0.109.2
45
- feedparser==6.0.10
46
- ffmpy==0.3.2
47
- filelock==3.13.1
48
- filetype==1.2.0
49
- flatbuffers==23.5.26
50
- fonttools==4.49.0
51
- frozenlist==1.4.1
52
- fsspec==2024.2.0
53
- google-auth==2.28.0
54
- google-search-results==2.4.2
55
- googleapis-common-protos==1.62.0
56
- gradio==3.48.0
57
- gradio_client==0.6.1
58
- greenlet==3.0.3
59
- grpcio==1.60.1
60
- h11==0.14.0
61
- httpcore==1.0.3
62
- httptools==0.6.1
63
- httpx==0.26.0
64
- huggingface-hub==0.20.3
65
- humanfriendly==10.0
66
- idna==3.6
67
- importlib-metadata==6.11.0
68
- importlib-resources==6.1.1
69
- Jinja2==3.1.3
70
- jmespath==1.0.1
71
- joblib==1.3.2
72
- jsonpatch==1.33
73
- jsonpath-python==1.0.6
74
- jsonpointer==2.4
75
- jsonschema==4.21.1
76
- jsonschema-specifications==2023.12.1
77
- kiwisolver==1.4.5
78
- kubernetes==29.0.0
79
- langchain==0.1.7
80
- langchain-community==0.0.20
81
- langchain-core==0.1.23
82
- langchainhub==0.1.14
83
- langdetect==1.0.9
84
- langsmith==0.0.87
85
- lxml==5.1.0
86
- MarkupSafe==2.1.5
87
- marshmallow==3.20.2
88
- matplotlib==3.8.3
89
- mmh3==4.1.0
90
- monotonic==1.6
91
- mpmath==1.3.0
92
- multidict==6.0.5
93
- mypy-extensions==1.0.0
94
- networkx==3.2.1
95
- nltk==3.8.1
96
- numpy==1.26.4
97
- nvidia-cublas-cu12==12.1.3.1
98
- nvidia-cuda-cupti-cu12==12.1.105
99
- nvidia-cuda-nvrtc-cu12==12.1.105
100
- nvidia-cuda-runtime-cu12==12.1.105
101
- nvidia-cudnn-cu12==8.9.2.26
102
- nvidia-cufft-cu12==11.0.2.54
103
- nvidia-curand-cu12==10.3.2.106
104
- nvidia-cusolver-cu12==11.4.5.107
105
- nvidia-cusparse-cu12==12.1.0.106
106
- nvidia-nccl-cu12==2.19.3
107
- nvidia-nvjitlink-cu12==12.3.101
108
- nvidia-nvtx-cu12==12.1.105
109
- oauthlib==3.2.2
110
- onnxruntime==1.17.0
111
- opentelemetry-api==1.22.0
112
- opentelemetry-exporter-otlp-proto-common==1.22.0
113
- opentelemetry-exporter-otlp-proto-grpc==1.22.0
114
- opentelemetry-instrumentation==0.43b0
115
- opentelemetry-instrumentation-asgi==0.43b0
116
- opentelemetry-instrumentation-fastapi==0.43b0
117
- opentelemetry-proto==1.22.0
118
- opentelemetry-sdk==1.22.0
119
- opentelemetry-semantic-conventions==0.43b0
120
- opentelemetry-util-http==0.43b0
121
- orjson==3.9.14
122
- overrides==7.7.0
123
- packaging==23.2
124
- pandas==2.2.0
125
- pillow==10.2.0
126
- posthog==3.4.1
127
- protobuf==4.25.3
128
- pulsar-client==3.4.0
129
- pyasn1==0.5.1
130
- pyasn1-modules==0.3.0
131
- pydantic==2.6.1
132
- pydantic_core==2.16.2
133
- pydub==0.25.1
134
- pyparsing==3.1.1
135
- PyPika==0.48.9
136
- pyproject_hooks==1.0.0
137
- python-dateutil==2.8.2
138
- python-dotenv==1.0.1
139
- python-iso639==2024.2.7
140
- python-magic==0.4.27
141
- python-multipart==0.0.9
142
- pytz==2024.1
143
- PyYAML==6.0.1
144
- rapidfuzz==3.6.1
145
- referencing==0.33.0
146
- regex==2023.12.25
147
- requests==2.31.0
148
- requests-oauthlib==1.3.1
149
- rpds-py==0.18.0
150
- rsa==4.9
151
- s3transfer==0.10.0
152
- safetensors==0.4.2
153
- scikit-learn==1.4.1.post1
154
- scipy==1.12.0
155
- semantic-version==2.10.0
156
- sentence-transformers==2.3.1
157
- sentencepiece==0.1.99
158
- sgmllib3k==1.0.0
159
- six==1.16.0
160
- sniffio==1.3.0
161
- soupsieve==2.5
162
- SQLAlchemy==2.0.27
163
- starlette==0.36.3
164
- sympy==1.12
165
- tabulate==0.9.0
166
- tenacity==8.2.3
167
- threadpoolctl==3.3.0
168
- tokenizers==0.15.2
169
- tomli==2.0.1
170
- toolz==0.12.1
171
- torch==2.2.0
172
- tqdm==4.66.2
173
- transformers==4.37.2
174
- triton==2.2.0
175
- typer==0.9.0
176
- types-requests==2.31.0.20240125
177
- typing-inspect==0.9.0
178
- typing_extensions==4.8.0
179
- tzdata==2024.1
180
- unstructured==0.12.4
181
- unstructured-client==0.18.0
182
- urllib3==2.0.7
183
- uvicorn==0.27.1
184
- uvloop==0.19.0
185
- validators==0.22.0
186
- watchfiles==0.21.0
187
- websocket-client==1.7.0
188
- websockets==11.0.3
189
- wrapt==1.16.0
190
- yarl==1.9.4
191
- zipp==3.17.0
192
- aiofiles==23.2.1
193
- aiohttp==3.9.3
194
- aiosignal==1.3.1
195
- altair==5.2.0
196
- annotated-types==0.6.0
197
- anyio==4.2.0
198
- arxiv==2.1.0
199
- asgiref==3.7.2
200
- async-timeout==4.0.3
201
- attrs==23.2.0
202
- backoff==2.2.1
203
- bcrypt==4.1.2
204
- beautifulsoup4==4.12.3
205
- boto3==1.34.42
206
- botocore==1.34.42
207
- build==1.0.3
208
- cachetools==5.3.2
209
- certifi==2024.2.2
210
- chardet==5.2.0
211
- charset-normalizer==3.3.2
212
- chroma-hnswlib==0.7.3
213
- chromadb==0.4.22
214
- click==8.1.7
215
- coloredlogs==15.0.1
216
- contourpy==1.2.0
217
- cycler==0.12.1
218
- dataclasses-json==0.6.4
219
- dataclasses-json-speakeasy==0.5.11
220
- Deprecated==1.2.14
221
- emoji==2.10.1
222
- exceptiongroup==1.2.0
223
- faiss-cpu==1.7.4
224
- fastapi==0.109.2
225
- feedparser==6.0.10
226
- ffmpy==0.3.2
227
- filelock==3.13.1
228
- filetype==1.2.0
229
- flatbuffers==23.5.26
230
- fonttools==4.48.1
231
- frozenlist==1.4.1
232
- fsspec==2024.2.0
233
- gitdb==4.0.11
234
- GitPython==3.1.41
235
- google-auth==2.27.0
236
- google_search_results==2.4.2
237
- googleapis-common-protos==1.62.0
238
- gradio==3.48.0
239
- gradio_client==0.6.1
240
- greenlet==3.0.3
241
- grpcio==1.60.1
242
- h11==0.14.0
243
- httpcore==1.0.3
244
- httptools==0.6.1
245
- httpx==0.26.0
246
- huggingface-hub==0.20.3
247
- humanfriendly==10.0
248
- idna==3.6
249
- importlib-metadata==6.11.0
250
- importlib-resources==6.1.1
251
- Jinja2==3.1.3
252
- jmespath==1.0.1
253
- joblib==1.3.2
254
- jsonpatch==1.33
255
- jsonpath-python==1.0.6
256
- jsonpointer==2.4
257
- jsonschema==4.21.1
258
- jsonschema-specifications==2023.12.1
259
- kiwisolver==1.4.5
260
- kubernetes==29.0.0
261
- langchain==0.1.7
262
- langchain-community==0.0.20
263
- langchain-core==0.1.23
264
- langchainhub==0.1.14
265
- langdetect==1.0.9
266
- langsmith==0.0.87
267
- lxml==5.1.0
268
- MarkupSafe==2.1.5
269
- marshmallow==3.20.2
270
- matplotlib==3.8.3
271
- mmh3==4.1.0
272
- monotonic==1.6
273
- mpmath==1.3.0
274
- multidict==6.0.5
275
- mypy-extensions==1.0.0
276
- networkx==3.2.1
277
- nltk==3.8.1
278
- numpy==1.26.4
279
- nvidia-cublas-cu12==12.1.3.1
280
- nvidia-cuda-cupti-cu12==12.1.105
281
- nvidia-cuda-nvrtc-cu12==12.1.105
282
- nvidia-cuda-runtime-cu12==12.1.105
283
- nvidia-cudnn-cu12==8.9.2.26
284
- nvidia-cufft-cu12==11.0.2.54
285
- nvidia-curand-cu12==10.3.2.106
286
- nvidia-cusolver-cu12==11.4.5.107
287
- nvidia-cusparse-cu12==12.1.0.106
288
- nvidia-nccl-cu12==2.19.3
289
- nvidia-nvjitlink-cu12==12.3.101
290
- nvidia-nvtx-cu12==12.1.105
291
- oauthlib==3.2.2
292
- onnxruntime==1.17.0
293
- opentelemetry-api==1.22.0
294
- opentelemetry-exporter-otlp-proto-common==1.22.0
295
- opentelemetry-exporter-otlp-proto-grpc==1.22.0
296
- opentelemetry-instrumentation==0.43b0
297
- opentelemetry-instrumentation-asgi==0.43b0
298
- opentelemetry-instrumentation-fastapi==0.43b0
299
- opentelemetry-proto==1.22.0
300
- opentelemetry-sdk==1.22.0
301
- opentelemetry-semantic-conventions==0.43b0
302
- opentelemetry-util-http==0.43b0
303
- orjson==3.9.14
304
- overrides==7.7.0
305
- packaging==23.2
306
- pandas==2.2.0
307
- pillow==10.2.0
308
- posthog==3.4.1
309
- protobuf==4.25.2
310
- pulsar-client==3.4.0
311
- pyasn1==0.5.1
312
- pyasn1-modules==0.3.0
313
- pydantic==2.6.1
314
- pydantic_core==2.16.2
315
- pydub==0.25.1
316
- pyparsing==3.1.1
317
- PyPika==0.48.9
318
- pyproject_hooks==1.0.0
319
- python-dateutil==2.8.2
320
- python-dotenv==1.0.1
321
- python-iso639==2024.2.7
322
- python-magic==0.4.27
323
- python-multipart==0.0.9
324
- pytz==2024.1
325
- PyYAML==6.0.1
326
- rapidfuzz==3.6.1
327
- referencing==0.33.0
328
- regex==2023.12.25
329
- requests==2.31.0
330
- requests-oauthlib==1.3.1
331
- rpds-py==0.18.0
332
- rsa==4.9
333
- s3transfer==0.10.0
334
- safetensors==0.4.2
335
- scikit-learn==1.4.0
336
- scipy==1.12.0
337
- semantic-version==2.10.0
338
- sentence-transformers==2.3.1
339
- sentencepiece==0.1.99
340
- sgmllib3k==1.0.0
341
- six==1.16.0
342
- smmap==5.0.1
343
- sniffio==1.3.0
344
- soupsieve==2.5
345
- SQLAlchemy==2.0.27
346
- starlette==0.36.3
347
- sympy==1.12
348
- tabulate==0.9.0
349
- tenacity==8.2.3
350
- threadpoolctl==3.3.0
351
- tokenizers==0.15.2
352
- tomli==2.0.1
353
- toolz==0.12.1
354
- torch==2.2.0
355
- tqdm==4.66.2
356
- transformers==4.37.2
357
- triton==2.2.0
358
- typer==0.9.0
359
- types-requests==2.31.0.20240125
360
- typing-inspect==0.9.0
361
- typing_extensions==4.8.0
362
- tzdata==2024.1
363
- unstructured==0.12.4
364
- unstructured-client==0.18.0
365
- urllib3==2.0.7
366
- uvicorn==0.27.1
367
- uvloop==0.19.0
368
- validators==0.22.0
369
- watchfiles==0.21.0
370
- websocket-client==1.7.0
371
- websockets==11.0.3
372
- wrapt==1.16.0
373
- yarl==1.9.4
374
- zipp==3.17.0
 
1
+ langchain-community
2
+ langchain
3
+ google-search-results
4
+ langchainhub
5
+ text_generation
6
+ arxiv
7
+ wikipedia
 
 
8
  gradio==3.48.0
9
+ chromadb
10
+ google_api_python_client