microstronger commited on
Commit
a718377
·
verified ·
1 Parent(s): 79213c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -134
app.py CHANGED
@@ -1,136 +1,154 @@
1
- import asyncio
2
  import json
3
- import logging
4
- from copy import deepcopy
5
- from dataclasses import asdict
6
- from typing import Dict, List, Union
7
-
8
- import janus
9
- from fastapi import FastAPI
10
- from fastapi.middleware.cors import CORSMiddleware
11
  from lagent.schema import AgentStatusCode
12
- from pydantic import BaseModel
13
- from sse_starlette.sse import EventSourceResponse
14
-
15
- from mindsearch.agent import init_agent
16
-
17
-
18
- def parse_arguments():
19
- import argparse
20
- parser = argparse.ArgumentParser(description='MindSearch API')
21
- parser.add_argument('--lang', default='cn', type=str, help='Language')
22
- parser.add_argument('--model_format',
23
- default='internlm_server',
24
- type=str,
25
- help='Model format')
26
- parser.add_argument('--search_engine',
27
- default='DuckDuckGoSearch',
28
- type=str,
29
- help='Search engine')
30
- return parser.parse_args()
31
-
32
-
33
- args = parse_arguments()
34
- app = FastAPI(docs_url='/')
35
-
36
- app.add_middleware(CORSMiddleware,
37
- allow_origins=['*'],
38
- allow_credentials=True,
39
- allow_methods=['*'],
40
- allow_headers=['*'])
41
-
42
-
43
- class GenerationParams(BaseModel):
44
- inputs: Union[str, List[Dict]]
45
- agent_cfg: Dict = dict()
46
-
47
-
48
- @app.post('/solve')
49
- async def run(request: GenerationParams):
50
-
51
- def convert_adjacency_to_tree(adjacency_input, root_name):
52
-
53
- def build_tree(node_name):
54
- node = {'name': node_name, 'children': []}
55
- if node_name in adjacency_input:
56
- for child in adjacency_input[node_name]:
57
- child_node = build_tree(child['name'])
58
- child_node['state'] = child['state']
59
- child_node['id'] = child['id']
60
- node['children'].append(child_node)
61
- return node
62
-
63
- return build_tree(root_name)
64
-
65
- async def generate():
66
- try:
67
- queue = janus.Queue()
68
- stop_event = asyncio.Event()
69
-
70
- # Wrapping a sync generator as an async generator using run_in_executor
71
- def sync_generator_wrapper():
72
- try:
73
- for response in agent.stream_chat(inputs):
74
- queue.sync_q.put(response)
75
- except Exception as e:
76
- logging.exception(
77
- f'Exception in sync_generator_wrapper: {e}')
78
- finally:
79
- # Notify async_generator_wrapper that the data generation is complete.
80
- queue.sync_q.put(None)
81
-
82
- async def async_generator_wrapper():
83
- loop = asyncio.get_event_loop()
84
- loop.run_in_executor(None, sync_generator_wrapper)
85
- while True:
86
- response = await queue.async_q.get()
87
- if response is None: # Ensure that all elements are consumed
88
- break
89
- yield response
90
- if not isinstance(
91
- response,
92
- tuple) and response.state == AgentStatusCode.END:
93
- break
94
- stop_event.set() # Inform sync_generator_wrapper to stop
95
-
96
- async for response in async_generator_wrapper():
97
- if isinstance(response, tuple):
98
- agent_return, node_name = response
99
- else:
100
- agent_return = response
101
- node_name = None
102
- origin_adj = deepcopy(agent_return.adjacency_list)
103
- adjacency_list = convert_adjacency_to_tree(
104
- agent_return.adjacency_list, 'root')
105
- assert adjacency_list[
106
- 'name'] == 'root' and 'children' in adjacency_list
107
- agent_return.adjacency_list = adjacency_list['children']
108
- agent_return = asdict(agent_return)
109
- agent_return['adj'] = origin_adj
110
- response_json = json.dumps(dict(response=agent_return,
111
- current_node=node_name),
112
- ensure_ascii=False)
113
- yield {'data': response_json}
114
- # yield f'data: {response_json}\n\n'
115
- except Exception as exc:
116
- msg = 'An error occurred while generating the response.'
117
- logging.exception(msg)
118
- response_json = json.dumps(
119
- dict(error=dict(msg=msg, details=str(exc))),
120
- ensure_ascii=False)
121
- yield {'data': response_json}
122
- # yield f'data: {response_json}\n\n'
123
- finally:
124
- await stop_event.wait(
125
- ) # Waiting for async_generator_wrapper to stop
126
- queue.close()
127
- await queue.wait_closed()
128
-
129
- inputs = request.inputs
130
- agent = init_agent(lang=args.lang, model_format=args.model_format,search_engine=args.search_engine)
131
- return EventSourceResponse(generate())
132
-
133
-
134
- if __name__ == '__main__':
135
- import uvicorn
136
- uvicorn.run(app, host='0.0.0.0', port=8002, log_level='info')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
+ import os
3
+
4
+ import gradio as gr
5
+ import requests
 
 
 
 
6
  from lagent.schema import AgentStatusCode
7
+
8
+ os.system("python -m mindsearch.app --lang cn --model_format internlm_silicon &")
9
+
10
+ PLANNER_HISTORY = []
11
+ SEARCHER_HISTORY = []
12
+
13
+
14
+ def rst_mem(history_planner: list, history_searcher: list):
15
+ '''
16
+ Reset the chatbot memory.
17
+ '''
18
+ history_planner = []
19
+ history_searcher = []
20
+ if PLANNER_HISTORY:
21
+ PLANNER_HISTORY.clear()
22
+ return history_planner, history_searcher
23
+
24
+
25
+ def format_response(gr_history, agent_return):
26
+ if agent_return['state'] in [
27
+ AgentStatusCode.STREAM_ING, AgentStatusCode.ANSWER_ING
28
+ ]:
29
+ gr_history[-1][1] = agent_return['response']
30
+ elif agent_return['state'] == AgentStatusCode.PLUGIN_START:
31
+ thought = gr_history[-1][1].split('```')[0]
32
+ if agent_return['response'].startswith('```'):
33
+ gr_history[-1][1] = thought + '\n' + agent_return['response']
34
+ elif agent_return['state'] == AgentStatusCode.PLUGIN_END:
35
+ thought = gr_history[-1][1].split('```')[0]
36
+ if isinstance(agent_return['response'], dict):
37
+ gr_history[-1][
38
+ 1] = thought + '\n' + f'```json\n{json.dumps(agent_return["response"], ensure_ascii=False, indent=4)}\n```' # noqa: E501
39
+ elif agent_return['state'] == AgentStatusCode.PLUGIN_RETURN:
40
+ assert agent_return['inner_steps'][-1]['role'] == 'environment'
41
+ item = agent_return['inner_steps'][-1]
42
+ gr_history.append([
43
+ None,
44
+ f"```json\n{json.dumps(item['content'], ensure_ascii=False, indent=4)}\n```"
45
+ ])
46
+ gr_history.append([None, ''])
47
+ return
48
+
49
+
50
+ def predict(history_planner, history_searcher):
51
+
52
+ def streaming(raw_response):
53
+ for chunk in raw_response.iter_lines(chunk_size=8192,
54
+ decode_unicode=False,
55
+ delimiter=b'\n'):
56
+ if chunk:
57
+ decoded = chunk.decode('utf-8')
58
+ if decoded == '\r':
59
+ continue
60
+ if decoded[:6] == 'data: ':
61
+ decoded = decoded[6:]
62
+ elif decoded.startswith(': ping - '):
63
+ continue
64
+ response = json.loads(decoded)
65
+ yield (response['response'], response['current_node'])
66
+
67
+ global PLANNER_HISTORY
68
+ PLANNER_HISTORY.append(dict(role='user', content=history_planner[-1][0]))
69
+ new_search_turn = True
70
+
71
+ url = 'http://localhost:8002/solve'
72
+ headers = {'Content-Type': 'application/json'}
73
+ data = {'inputs': PLANNER_HISTORY}
74
+ raw_response = requests.post(url,
75
+ headers=headers,
76
+ data=json.dumps(data),
77
+ timeout=20,
78
+ stream=True)
79
+
80
+ for resp in streaming(raw_response):
81
+ agent_return, node_name = resp
82
+ if node_name:
83
+ if node_name in ['root', 'response']:
84
+ continue
85
+ agent_return = agent_return['nodes'][node_name]['detail']
86
+ if new_search_turn:
87
+ history_searcher.append([agent_return['content'], ''])
88
+ new_search_turn = False
89
+ format_response(history_searcher, agent_return)
90
+ if agent_return['state'] == AgentStatusCode.END:
91
+ new_search_turn = True
92
+ yield history_planner, history_searcher
93
+ else:
94
+ new_search_turn = True
95
+ format_response(history_planner, agent_return)
96
+ if agent_return['state'] == AgentStatusCode.END:
97
+ PLANNER_HISTORY = agent_return['inner_steps']
98
+ yield history_planner, history_searcher
99
+ return history_planner, history_searcher
100
+
101
+
102
+ with gr.Blocks() as demo:
103
+ gr.HTML("""<h1 align="center">MindSearch Gradio Demo</h1>""")
104
+ gr.HTML("""<p style="text-align: center; font-family: Arial, sans-serif;">MindSearch is an open-source AI Search Engine Framework with Perplexity.ai Pro performance. You can deploy your own Perplexity.ai-style search engine using either closed-source LLMs (GPT, Claude) or open-source LLMs (InternLM2.5-7b-chat).</p>""")
105
+ gr.HTML("""
106
+ <div style="text-align: center; font-size: 16px;">
107
+ <a href="https://github.com/InternLM/MindSearch" style="margin-right: 15px; text-decoration: none; color: #4A90E2;">🔗 GitHub</a>
108
+ <a href="https://arxiv.org/abs/2407.20183" style="margin-right: 15px; text-decoration: none; color: #4A90E2;">📄 Arxiv</a>
109
+ <a href="https://huggingface.co/papers/2407.20183" style="margin-right: 15px; text-decoration: none; color: #4A90E2;">📚 Hugging Face Papers</a>
110
+ <a href="https://huggingface.co/spaces/internlm/MindSearch" style="text-decoration: none; color: #4A90E2;">🤗 Hugging Face Demo</a>
111
+ </div>
112
+ """)
113
+ with gr.Row():
114
+ with gr.Column(scale=10):
115
+ with gr.Row():
116
+ with gr.Column():
117
+ planner = gr.Chatbot(label='planner',
118
+ height=700,
119
+ show_label=True,
120
+ show_copy_button=True,
121
+ bubble_full_width=False,
122
+ render_markdown=True)
123
+ with gr.Column():
124
+ searcher = gr.Chatbot(label='searcher',
125
+ height=700,
126
+ show_label=True,
127
+ show_copy_button=True,
128
+ bubble_full_width=False,
129
+ render_markdown=True)
130
+ with gr.Row():
131
+ user_input = gr.Textbox(show_label=False,
132
+ placeholder='帮我搜索一下 InternLM 开源体系',
133
+ lines=5,
134
+ container=False)
135
+ with gr.Row():
136
+ with gr.Column(scale=2):
137
+ submitBtn = gr.Button('Submit')
138
+ with gr.Column(scale=1, min_width=20):
139
+ emptyBtn = gr.Button('Clear History')
140
+
141
+ def user(query, history):
142
+ return '', history + [[query, '']]
143
+
144
+ submitBtn.click(user, [user_input, planner], [user_input, planner],
145
+ queue=False).then(predict, [planner, searcher],
146
+ [planner, searcher])
147
+ emptyBtn.click(rst_mem, [planner, searcher], [planner, searcher],
148
+ queue=False)
149
+
150
+ demo.queue()
151
+ demo.launch(server_name='0.0.0.0',
152
+ server_port=7860,
153
+ inbrowser=True,
154
+ share=True)