sabazo committed on
Commit
25a2ce0
2 Parent(s): 68b7b58 fcbd089

Merge pull request #2 from almutareb/mixtral_agent

Browse files
Files changed (3) hide show
  1. example.env +8 -0
  2. mixtral_agent.py +117 -0
  3. requirements.txt +191 -11
example.env ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Hugging Face API token
2
+ HUGGINGFACEHUB_API_TOKEN=
3
+
4
+ # URL to interface with the Ollama API
5
+ OLLMA_BASE_URL=
6
+
7
+ # environment variables needed to use the tools
8
+ SERPAPI_API_KEY=
mixtral_agent.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# mixtral_agent.py
#
# A ReAct-style agent backed by an Ollama-hosted Mistral model, with two
# tools: an Arxiv retriever and a SerpAPI web search.
#
# LangChain supports many other chat models. Here, we're using Ollama.
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain.tools.retriever import create_retriever_tool
from langchain_community.utilities import SerpAPIWrapper
from langchain.retrievers import ArxivRetriever
from langchain_core.tools import Tool
from langchain import hub
from langchain.agents import AgentExecutor, load_tools  # load_tools currently unused; kept for parity
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers import (
    ReActJsonSingleInputOutputParser,
)
from langchain.tools.render import render_text_description
import os

import dotenv

# Load HUGGINGFACEHUB_API_TOKEN, OLLMA_BASE_URL and SERPAPI_API_KEY
# from a local .env file (see example.env).
dotenv.load_dotenv()

# NOTE(review): the name carries a typo ("OLLMA" vs "OLLAMA") but it must
# match the key used in example.env / deployed .env files, so it is kept.
OLLMA_BASE_URL = os.getenv("OLLMA_BASE_URL")


# ChatOllama supports many more optional parameters. Hover on your
# `ChatOllama(...)` class to view the latest available supported parameters.
llm = ChatOllama(
    model="mistral",
    base_url=OLLMA_BASE_URL,
)

# Simple smoke-test chain using LangChain Expression Language (LCEL) syntax.
# Learn more about LCEL at
# https://python.langchain.com/docs/expression_language/why
joke_prompt = ChatPromptTemplate.from_template("Tell me a short joke about {topic}")
chain = joke_prompt | llm | StrOutputParser()

# Retriever over the Arxiv API; fetches at most 2 documents per query.
retriever = ArxivRetriever(load_max_docs=2)

# Tool belt for the agent. NOTE(review): the retriever tool's name and
# description contain spaces/typos ("recomend"); they are LLM-facing strings
# and are preserved byte-for-byte to keep behavior identical — consider
# cleaning them up in a dedicated change.
tools = [
    create_retriever_tool(
        retriever,
        "search arxiv's database for",
        "Use this to recomend the user a paper to read Unless stated please choose the most recent models",
    ),
    Tool(
        name="SerpAPI",
        description="A low-cost Google Search API. Useful for when you need to answer questions about current events. Input should be a search query.",
        func=SerpAPIWrapper().run,
    ),
]

# Pull the community ReAct-JSON prompt and bake in the tool descriptions.
prompt = hub.pull("hwchase17/react-json")
prompt = prompt.partial(
    tools=render_text_description(tools),
    tool_names=", ".join([t.name for t in tools]),
)

# Define the agent: stop generation at the Observation marker so the
# executor (not the model) runs the tool and feeds the result back.
chat_model = llm
chat_model_with_stop = chat_model.bind(stop=["\nObservation"])
agent = (
    {
        "input": lambda x: x["input"],
        "agent_scratchpad": lambda x: format_log_to_str(x["intermediate_steps"]),
    }
    | prompt
    | chat_model_with_stop
    | ReActJsonSingleInputOutputParser()
)

# Instantiate the AgentExecutor. handle_parsing_errors=True feeds malformed
# model output back to the LLM instead of raising.
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    verbose=True,
    handle_parsing_errors=True,
)


if __name__ == "__main__":
    # Smoke-test the base LLM chain; response is printed in the terminal.
    # You can use LangServe to deploy your application for production.
    print(chain.invoke({"topic": "Space travel"}))

    # Demo query. NOTE(review): "macchine"/"axriv" typos are preserved —
    # this string is sent verbatim to the model.
    agent_executor.invoke(
        {
            "input": "How to generate videos from images using state of the art macchine learning models; Using the axriv retriever " +
            "add the urls of the papers used in the final answer using the metadata from the retriever"
        }
    )
requirements.txt CHANGED
@@ -1,11 +1,191 @@
1
- langchain
2
- langchain-community
3
- beautifulsoup4
4
- faiss-cpu
5
- chromadb
6
- validators
7
- sentence_transformers
8
- typing-extensions==4.8.0
9
- unstructured
10
- gradio==3.48.0
11
- boto3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ langchain
2
+ langchain-community
3
+ beautifulsoup4
4
+ faiss-cpu
5
+ chromadb
6
+ validators
7
+ sentence_transformers
8
+ typing-extensions==4.8.0
9
+ unstructured
10
+ gradio==3.48.0
11
+ boto3
12
+ aiofiles==23.2.1
13
+ aiohttp==3.9.3
14
+ aiosignal==1.3.1
15
+ altair==5.2.0
16
+ annotated-types==0.6.0
17
+ anyio==4.2.0
18
+ arxiv==2.1.0
19
+ asgiref==3.7.2
20
+ async-timeout==4.0.3
21
+ attrs==23.2.0
22
+ backoff==2.2.1
23
+ bcrypt==4.1.2
24
+ beautifulsoup4==4.12.3
25
+ boto3==1.34.44
26
+ botocore==1.34.44
27
+ build==1.0.3
28
+ cachetools==5.3.2
29
+ certifi==2024.2.2
30
+ chardet==5.2.0
31
+ charset-normalizer==3.3.2
32
+ chroma-hnswlib==0.7.3
33
+ chromadb==0.4.22
34
+ click==8.1.7
35
+ coloredlogs==15.0.1
36
+ contourpy==1.2.0
37
+ cycler==0.12.1
38
+ dataclasses-json==0.6.4
39
+ dataclasses-json-speakeasy==0.5.11
40
+ Deprecated==1.2.14
41
+ emoji==2.10.1
42
+ exceptiongroup==1.2.0
43
+ faiss-cpu==1.7.4
44
+ fastapi==0.109.2
45
+ feedparser==6.0.10
46
+ ffmpy==0.3.2
47
+ filelock==3.13.1
48
+ filetype==1.2.0
49
+ flatbuffers==23.5.26
50
+ fonttools==4.49.0
51
+ frozenlist==1.4.1
52
+ fsspec==2024.2.0
53
+ google-auth==2.28.0
54
+ google-search-results==2.4.2
55
+ googleapis-common-protos==1.62.0
56
+ gradio==3.48.0
57
+ gradio_client==0.6.1
58
+ greenlet==3.0.3
59
+ grpcio==1.60.1
60
+ h11==0.14.0
61
+ httpcore==1.0.3
62
+ httptools==0.6.1
63
+ httpx==0.26.0
64
+ huggingface-hub==0.20.3
65
+ humanfriendly==10.0
66
+ idna==3.6
67
+ importlib-metadata==6.11.0
68
+ importlib-resources==6.1.1
69
+ Jinja2==3.1.3
70
+ jmespath==1.0.1
71
+ joblib==1.3.2
72
+ jsonpatch==1.33
73
+ jsonpath-python==1.0.6
74
+ jsonpointer==2.4
75
+ jsonschema==4.21.1
76
+ jsonschema-specifications==2023.12.1
77
+ kiwisolver==1.4.5
78
+ kubernetes==29.0.0
79
+ langchain==0.1.7
80
+ langchain-community==0.0.20
81
+ langchain-core==0.1.23
82
+ langchainhub==0.1.14
83
+ langdetect==1.0.9
84
+ langsmith==0.0.87
85
+ lxml==5.1.0
86
+ MarkupSafe==2.1.5
87
+ marshmallow==3.20.2
88
+ matplotlib==3.8.3
89
+ mmh3==4.1.0
90
+ monotonic==1.6
91
+ mpmath==1.3.0
92
+ multidict==6.0.5
93
+ mypy-extensions==1.0.0
94
+ networkx==3.2.1
95
+ nltk==3.8.1
96
+ numpy==1.26.4
97
+ nvidia-cublas-cu12==12.1.3.1
98
+ nvidia-cuda-cupti-cu12==12.1.105
99
+ nvidia-cuda-nvrtc-cu12==12.1.105
100
+ nvidia-cuda-runtime-cu12==12.1.105
101
+ nvidia-cudnn-cu12==8.9.2.26
102
+ nvidia-cufft-cu12==11.0.2.54
103
+ nvidia-curand-cu12==10.3.2.106
104
+ nvidia-cusolver-cu12==11.4.5.107
105
+ nvidia-cusparse-cu12==12.1.0.106
106
+ nvidia-nccl-cu12==2.19.3
107
+ nvidia-nvjitlink-cu12==12.3.101
108
+ nvidia-nvtx-cu12==12.1.105
109
+ oauthlib==3.2.2
110
+ onnxruntime==1.17.0
111
+ opentelemetry-api==1.22.0
112
+ opentelemetry-exporter-otlp-proto-common==1.22.0
113
+ opentelemetry-exporter-otlp-proto-grpc==1.22.0
114
+ opentelemetry-instrumentation==0.43b0
115
+ opentelemetry-instrumentation-asgi==0.43b0
116
+ opentelemetry-instrumentation-fastapi==0.43b0
117
+ opentelemetry-proto==1.22.0
118
+ opentelemetry-sdk==1.22.0
119
+ opentelemetry-semantic-conventions==0.43b0
120
+ opentelemetry-util-http==0.43b0
121
+ orjson==3.9.14
122
+ overrides==7.7.0
123
+ packaging==23.2
124
+ pandas==2.2.0
125
+ pillow==10.2.0
126
+ posthog==3.4.1
127
+ protobuf==4.25.3
128
+ pulsar-client==3.4.0
129
+ pyasn1==0.5.1
130
+ pyasn1-modules==0.3.0
131
+ pydantic==2.6.1
132
+ pydantic_core==2.16.2
133
+ pydub==0.25.1
134
+ pyparsing==3.1.1
135
+ PyPika==0.48.9
136
+ pyproject_hooks==1.0.0
137
+ python-dateutil==2.8.2
138
+ python-dotenv==1.0.1
139
+ python-iso639==2024.2.7
140
+ python-magic==0.4.27
141
+ python-multipart==0.0.9
142
+ pytz==2024.1
143
+ PyYAML==6.0.1
144
+ rapidfuzz==3.6.1
145
+ referencing==0.33.0
146
+ regex==2023.12.25
147
+ requests==2.31.0
148
+ requests-oauthlib==1.3.1
149
+ rpds-py==0.18.0
150
+ rsa==4.9
151
+ s3transfer==0.10.0
152
+ safetensors==0.4.2
153
+ scikit-learn==1.4.1.post1
154
+ scipy==1.12.0
155
+ semantic-version==2.10.0
156
+ sentence-transformers==2.3.1
157
+ sentencepiece==0.1.99
158
+ sgmllib3k==1.0.0
159
+ six==1.16.0
160
+ sniffio==1.3.0
161
+ soupsieve==2.5
162
+ SQLAlchemy==2.0.27
163
+ starlette==0.36.3
164
+ sympy==1.12
165
+ tabulate==0.9.0
166
+ tenacity==8.2.3
167
+ threadpoolctl==3.3.0
168
+ tokenizers==0.15.2
169
+ tomli==2.0.1
170
+ toolz==0.12.1
171
+ torch==2.2.0
172
+ tqdm==4.66.2
173
+ transformers==4.37.2
174
+ triton==2.2.0
175
+ typer==0.9.0
176
+ types-requests==2.31.0.20240125
177
+ typing-inspect==0.9.0
178
+ typing_extensions==4.8.0
179
+ tzdata==2024.1
180
+ unstructured==0.12.4
181
+ unstructured-client==0.18.0
182
+ urllib3==2.0.7
183
+ uvicorn==0.27.1
184
+ uvloop==0.19.0
185
+ validators==0.22.0
186
+ watchfiles==0.21.0
187
+ websocket-client==1.7.0
188
+ websockets==11.0.3
189
+ wrapt==1.16.0
190
+ yarl==1.9.4
191
+ zipp==3.17.0