Asaad Almutareb commited on
Commit
da70771
1 Parent(s): 9620371

cleaned unused code, moved prompt to separate file

Browse files
Files changed (1) hide show
  1. hf_mixtral_agent.py +10 -71
hf_mixtral_agent.py CHANGED
@@ -21,11 +21,9 @@ from innovation_pathfinder_ai.structured_tools.structured_tools import (
21
  from innovation_pathfinder_ai.source_container.container import (
22
  all_sources
23
  )
24
- from innovation_pathfinder_ai.utils import collect_urls
25
- # from langchain_community.chat_message_histories import ChatMessageHistory
26
- # from langchain_core.runnables.history import RunnableWithMessageHistory
27
 
28
- # message_history = ChatMessageHistory()
29
  config = load_dotenv(".env")
30
  HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
31
  GOOGLE_CSE_ID = os.getenv('GOOGLE_CSE_ID')
@@ -44,7 +42,7 @@ llm = HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
44
  )
45
 
46
 
47
- tools_all = [
48
  arxiv_search,
49
  wikipedia_search,
50
  google_search,
@@ -58,7 +56,9 @@ tools_papers = [
58
  ]
59
 
60
 
61
- prompt = hub.pull("hwchase17/react-json")
 
 
62
  prompt = prompt.partial(
63
  tools=render_text_description(tools),
64
  tool_names=", ".join([t.name for t in tools]),
@@ -71,6 +71,7 @@ agent = (
71
  {
72
  "input": lambda x: x["input"],
73
  "agent_scratchpad": lambda x: format_log_to_str(x["intermediate_steps"]),
 
74
  }
75
  | prompt
76
  | chat_model_with_stop
@@ -78,9 +79,9 @@ agent = (
78
  )
79
 
80
  # instantiate AgentExecutor
81
- agent_executor_all = AgentExecutor(
82
  agent=agent,
83
- tools=tools_all,
84
  verbose=True,
85
  max_iterations=6, # cap number of iterations
86
  #max_execution_time=60, # timout at 60 sec
@@ -97,66 +98,4 @@ agent_executor_noweb = AgentExecutor(
97
  #max_execution_time=60, # timout at 60 sec
98
  return_intermediate_steps=True,
99
  handle_parsing_errors=True,
100
- )
101
-
102
-
103
- if __name__ == "__main__":
104
-
105
- def add_text(history, text):
106
- history = history + [(text, None)]
107
- return history, ""
108
-
109
- def bot(history):
110
- response = infer(history[-1][0], history)
111
- sources = collect_urls(all_sources)
112
- src_list = '\n'.join(sources)
113
- response_w_sources = response['output']+"\n\n\n Sources: \n\n\n"+src_list
114
- intermediate_steps = response['intermediate_steps']
115
- history[-1][1] = response_w_sources
116
- return history
117
-
118
- def infer(question, history):
119
- query = question
120
- result = agent_executor_all.invoke(
121
- {
122
- "input": question,
123
- }
124
- )
125
- return result
126
-
127
- def vote(data: gr.LikeData):
128
- if data.liked:
129
- print("You upvoted this response: " + data.value)
130
- else:
131
- print("You downvoted this response: " + data.value)
132
-
133
-
134
- css="""
135
- #col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
136
- """
137
-
138
- title = """
139
- <div style="text-align: center;max-width: 700px;">
140
- <p>Hello Dave, how can I help today?<br />
141
- </div>
142
- """
143
-
144
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
145
- with gr.Tab("Google|Wikipedia|Arxiv"):
146
- with gr.Column(elem_id="col-container"):
147
- gr.HTML(title)
148
- with gr.Row():
149
- question = gr.Textbox(label="Question", placeholder="Type your question and hit Enter ")
150
- chatbot = gr.Chatbot([], elem_id="chatbot")
151
- chatbot.like(vote, None, None)
152
- clear = gr.Button("Clear")
153
- question.submit(add_text, [chatbot, question], [chatbot, question], queue=False).then(
154
- bot, chatbot, chatbot
155
- )
156
- clear.click(lambda: None, None, chatbot, queue=False)
157
-
158
- demo.queue()
159
- demo.launch(debug=True)
160
-
161
-
162
- x = 0 # for debugging purposes
 
21
  from innovation_pathfinder_ai.source_container.container import (
22
  all_sources
23
  )
24
+ from langchain import PromptTemplate
25
+ from innovation_pathfinder_ai.templates.react_json_with_memory import template_system
 
26
 
 
27
  config = load_dotenv(".env")
28
  HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
29
  GOOGLE_CSE_ID = os.getenv('GOOGLE_CSE_ID')
 
42
  )
43
 
44
 
45
+ tools = [
46
  arxiv_search,
47
  wikipedia_search,
48
  google_search,
 
56
  ]
57
 
58
 
59
+ prompt = PromptTemplate.from_template(
60
+ template=template_system
61
+ )
62
  prompt = prompt.partial(
63
  tools=render_text_description(tools),
64
  tool_names=", ".join([t.name for t in tools]),
 
71
  {
72
  "input": lambda x: x["input"],
73
  "agent_scratchpad": lambda x: format_log_to_str(x["intermediate_steps"]),
74
+ "chat_history": lambda x: x["chat_history"],
75
  }
76
  | prompt
77
  | chat_model_with_stop
 
79
  )
80
 
81
  # instantiate AgentExecutor
82
+ agent_executor = AgentExecutor(
83
  agent=agent,
84
+ tools=tools,
85
  verbose=True,
86
  max_iterations=6, # cap number of iterations
87
  #max_execution_time=60, # timout at 60 sec
 
98
  #max_execution_time=60, # timout at 60 sec
99
  return_intermediate_steps=True,
100
  handle_parsing_errors=True,
101
+ )