# OmAgent / webpage.py
# Author: 韩宇
# Commit: "init req" (51e209e)
import html
import json
import os
import queue
import shutil
import sys
import threading
import uuid
from pathlib import Path
from time import sleep
# Point Gradio's temporary-file directory at the current working directory.
# This is set before `gradio` is imported further down, presumably so Gradio
# sees it from the start.
_working_dir = os.getcwd()
os.environ['GRADIO_TEMP_DIR'] = _working_dir
# Local videos listed in the UI dropdown are discovered under this folder,
# which is created on import if it does not exist yet.
video_root_path = os.path.join(_working_dir, 'video_root')
os.makedirs(video_root_path, exist_ok=True)
from omagent_core.clients.devices.app.callback import AppCallback
from omagent_core.clients.devices.app.input import AppInput
from omagent_core.clients.devices.app.schemas import ContentStatus, MessageType
from omagent_core.engine.automator.task_handler import TaskHandler
from omagent_core.engine.http.models.workflow_status import terminal_status
from omagent_core.engine.workflow.conductor_workflow import ConductorWorkflow
from omagent_core.services.connectors.redis import RedisConnector
from omagent_core.utils.build import build_from_file
from omagent_core.utils.container import container
from omagent_core.utils.logger import logging
from omagent_core.utils.registry import registry
# Presumably imports the project modules so their workers get registered and
# `registry.get_worker(...)` (used by WebpageClient.initialization) can resolve
# the names found in the worker config -- TODO confirm.
registry.import_module()
# Register the app-device implementations with the global container. The
# connector name "redis_stream_client" is what WebpageClient looks up via
# `container.get_connector(...)` for all Redis stream reads/writes.
container.register_connector(name="redis_stream_client", connector=RedisConnector)
# Short-term-memory registration is left disabled in this app.
# container.register_stm(stm='RedisSTM')
container.register_callback(callback=AppCallback)
container.register_input(input=AppInput)
import gradio as gr
class WebpageClient:
    """Gradio front-end that drives the OmAgent video workflows.

    Two UIs are offered:

    * ``start_interactor`` -> ``gradio_app``: pick a local video (which runs
      the *processor* workflow on it), then chat about it through the
      *interactor* workflow. Questions are published to the
      ``<workflow id>_input`` Redis stream; progress and answers are read back
      from ``<workflow id>_running`` / ``<workflow id>_output``.
    * ``start_processor``: an upload page served on port 7861 that pushes the
      uploaded files to the ``image_process`` Redis stream.
    """

    # File extensions (lower-case, without the dot) treated as videos when
    # scanning ``video_root_path``.
    _VIDEO_EXTENSIONS = frozenset(
        ('mp4', 'avi', 'mov', 'wmv', 'flv', 'mkv', 'webm', 'm4v')
    )

    def __init__(
        self,
        interactor: ConductorWorkflow = None,
        processor: ConductorWorkflow = None,
        config_path: str = "./config",
        workers: list = None,
    ) -> None:
        """Create the client and instantiate its workers.

        Args:
            interactor: Workflow started for chat questions about the video.
            processor: Workflow started for a selected video / uploaded files.
            config_path: Worker config location passed to ``build_from_file``.
            workers: Pre-built worker instances; ``None`` means none.
        """
        # A fresh list per instance: a shared ``[]`` default would be reused
        # (and possibly mutated) by every client built without workers.
        if workers is None:
            workers = []
        self._interactor = interactor
        self._processor = processor
        self._config_path = config_path
        self._workers = workers
        # Id of the workflow run currently in flight (chat or upload);
        # ``None`` means the next message starts a new run.
        self._workflow_instance_id = None
        self._worker_config = build_from_file(self._config_path)
        self._task_to_domain = {}
        # Accumulates streamed INCOMPLETE answer chunks until the final chunk.
        self._incomplete_message = ""
        self._custom_css = """
#OmAgent {
height: 100vh !important;
max-height: calc(100vh - 190px) !important;
overflow-y: auto;
}
.running-message {
margin: 0;
padding: 2px 4px;
white-space: pre-wrap;
word-wrap: break-word;
font-family: inherit;
}
/* Remove the background and border of the message box */
.message-wrap {
background: none !important;
border: none !important;
padding: 0 !important;
margin: 0 !important;
}
/* Remove the bubble style of the running message */
.message:has(.running-message) {
background: none !important;
border: none !important;
padding: 0 !important;
box-shadow: none !important;
}
"""
        # Fixed ids for this client: workers are bound to
        # ``workflow_instance_id``; the video processor run is polled via
        # ``processor_instance_id``.
        self.workflow_instance_id = str(uuid.uuid4())
        self.processor_instance_id = str(uuid.uuid4())
        worker_config = build_from_file(self._config_path)
        self.initialization(workers, worker_config)

    def initialization(self, workers, worker_config):
        """Index workers by name and bind them to this client's workflow id.

        Pre-built ``workers`` are keyed by their class name; each entry of
        ``worker_config`` is instantiated from the registry by its ``'name'``
        (receiving the whole entry as keyword arguments). A later entry with
        the same key overwrites an earlier one.
        """
        self.workers = {}
        for worker in workers:
            worker.workflow_instance_id = self.workflow_instance_id
            self.workers[type(worker).__name__] = worker
        for config in worker_config:
            worker_cls = registry.get_worker(config['name'])
            worker = worker_cls(**config)
            worker.workflow_instance_id = self.workflow_instance_id
            self.workers[config['name']] = worker

    def gradio_app(self):
        """Build and launch (blocking) the video chat UI.

        Left column: a dropdown of local videos and a player. Choosing a video
        runs the processor workflow on it (blocking until it reaches a
        terminal status) and stores the result in this session's ``gr.State``.
        Right column: a chat box wired to ``add_message`` -> ``bot``.
        """
        with gr.Blocks() as demo:

            def load_local_video() -> dict:
                """Map file name -> path for every video under ``video_root_path``.

                Subdirectories are scanned too; a later file with the same
                name overwrites an earlier one.
                """
                result = {}
                for root, _, files in os.walk(video_root_path):
                    for file in files:
                        if file.split('.')[-1].lower() in self._VIDEO_EXTENSIONS:
                            result[file] = os.path.join(root, file)
                return result

            def display_video_map(video_title, session_state):
                """Process the chosen video and record it in the session state.

                Returns the path for the video player and the updated state
                dict. Clearing the dropdown (or choosing a title that is no
                longer indexed) resets ``current_video`` without starting a
                workflow.
                """
                # Copy this session's state instead of mutating the State
                # component's default value, which every session shares.
                new_state = dict(self._session_state(session_state))
                video_path = (new_state.get('video_dict') or {}).get(video_title)
                if video_path is None:
                    new_state['current_video'] = None
                    return None, new_state
                try:
                    processor_result = self._processor.start_workflow_with_input(
                        workflow_input={'video_path': video_path},
                        workers=self.workers,
                    )
                except Exception as e:
                    logging.error(f"Error starting workflow: {e}")
                    raise
                processor_workflow_instance_id = self.processor_instance_id
                # Block until the processor run reaches a terminal status.
                while True:
                    status = self._processor.get_workflow(
                        workflow_id=processor_workflow_instance_id
                    ).status
                    if status in terminal_status:
                        break
                    sleep(1)
                new_state['video_dict'] = load_local_video()
                new_state.update(
                    current_video=video_path,
                    processor_result=processor_result,
                    processor_workflow_instance_id=processor_workflow_instance_id,
                )
                return video_path, new_state

            video_dict = load_local_video()
            current_video = None
            state = gr.State(value={
                'video_dict': video_dict,
                'current_video': current_video,
            })
            with gr.Row():
                with gr.Column():
                    with gr.Column():
                        select_video = gr.Dropdown(list(video_dict), value=None)
                        display_video = gr.Video(current_video)
                    select_video.change(
                        fn=display_video_map,
                        inputs=[select_video, state],
                        outputs=[display_video, state],
                    )
                with gr.Column():
                    chatbot = gr.Chatbot(
                        type="messages",
                    )
                    chat_input = gr.Textbox(
                        interactive=True,
                        placeholder="Enter message...",
                        show_label=False,
                    )
                    chat_msg = chat_input.submit(
                        self.add_message,
                        [chatbot, chat_input, state],
                        [chatbot, chat_input],
                    )
                    bot_msg = chat_msg.then(
                        self.bot, (chatbot, state), chatbot, api_name="bot_response"
                    )
                    # Re-enable the textbox once the answer has been streamed.
                    bot_msg.then(
                        lambda: gr.Textbox(interactive=True), None, [chat_input]
                    )
        demo.launch(
            max_file_size='1gb'
        )

    def start_interactor(self):
        """Run the video chat UI; on Ctrl+C terminate the in-flight workflow and re-raise."""
        try:
            self.gradio_app()
        except KeyboardInterrupt:
            logging.info("\nDetected Ctrl+C, stopping workflow...")
            if self._workflow_instance_id is not None:
                self._interactor._executor.terminate(
                    workflow_id=self._workflow_instance_id
                )
            raise

    def stop_interactor(self):
        """Print a marker and exit the process with status 0."""
        # self._task_handler_interactor.stop_processes()
        print("stop_interactor")
        sys.exit(0)

    def start_processor(self):
        """Start the processor task handler and serve the upload UI on port 7861.

        On Ctrl+C the in-flight processor workflow (if any) is terminated and
        the interrupt is re-raised.
        """
        self._task_handler_processor = TaskHandler(
            worker_config=self._worker_config,
            workers=self._workers,
            task_to_domain=self._task_to_domain,
        )
        self._task_handler_processor.start_processes()
        try:
            with gr.Blocks(title="OmAgent", css=self._custom_css) as chat_interface:
                chatbot = gr.Chatbot(
                    elem_id="OmAgent",
                    bubble_full_width=False,
                    type="messages",
                    height="100%",
                )
                chat_input = gr.MultimodalTextbox(
                    interactive=True,
                    file_count="multiple",
                    placeholder="Enter message or upload file...",
                    show_label=False,
                )
                chat_msg = chat_input.submit(
                    self.add_processor_message,
                    [chatbot, chat_input],
                    [chatbot, chat_input],
                )
                bot_msg = chat_msg.then(
                    self.processor_bot, chatbot, chatbot, api_name="bot_response"
                )
                bot_msg.then(
                    lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input]
                )
            chat_interface.launch(server_port=7861)
        except KeyboardInterrupt:
            logging.info("\nDetected Ctrl+C, stopping workflow...")
            if self._workflow_instance_id is not None:
                self._processor._executor.terminate(
                    workflow_id=self._workflow_instance_id
                )
            raise

    def stop_processor(self):
        """Stop the processes started by ``start_processor``."""
        self._task_handler_processor.stop_processes()

    @staticmethod
    def _session_state(state) -> dict:
        """Return the session's state dict, whether given the value or a ``gr.State``.

        ``None`` (or a component whose value is ``None``) yields ``{}``.
        """
        if isinstance(state, gr.State):
            state = state.value
        return state or {}

    @staticmethod
    def _redis_client():
        """Return the raw client behind the ``redis_stream_client`` connector."""
        return container.get_connector("redis_stream_client")._client

    @staticmethod
    def _reject_message(history, message, reply):
        """Echo ``message`` with an assistant ``reply`` and lock the textbox."""
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": reply})
        return history, gr.Textbox(value=None, interactive=False)

    @staticmethod
    def _upsert_assistant_message(history, content):
        """Overwrite the last message if it is the assistant's, else append one."""
        if history and history[-1]["role"] == "assistant":
            history[-1]["content"] = content
        else:
            history.append({"role": "assistant", "content": content})

    def _run_interactor(self, workflow_input):
        """Thread target: run the interactor workflow, logging and re-raising failures."""
        try:
            self._interactor.start_workflow_with_input(
                workflow_input=workflow_input, workers=self.workers
            )
        except Exception as e:
            logging.error(f"Error starting workflow: {e}")
            raise

    def add_message(self, history, message, state):
        """Append the user's question and forward it to the interactor workflow.

        If no interactor run is in flight, one is started in a daemon thread
        with the question and the processed-video info from ``state``. The
        question is then published to the ``<workflow id>_input`` Redis
        stream. Without a selected video, an assistant hint is appended
        instead and nothing is sent.

        Returns:
            The updated history and a cleared, non-interactive textbox.
        """
        if not isinstance(state, gr.State) and not state:
            return self._reject_message(history, message, 'Please reselect the video')
        state_data = self._session_state(state)
        if state_data.get('current_video') is None:
            return self._reject_message(history, message, 'Please select a video')
        if self._workflow_instance_id is None:
            # ``processor_result`` may be missing or None if processing
            # returned nothing; treat that as "no info" rather than crashing.
            processor_result = state_data.get('processor_result') or {}
            workflow_input = {
                'question': message,
                "video_md5": processor_result.get("video_md5"),
                "video_path": processor_result.get("video_path"),
                "instance_id": processor_result.get("instance_id"),
                "processor_workflow_instance_id": state_data.get(
                    "processor_workflow_instance_id"
                ),
            }
            workflow_thread = threading.Thread(
                target=self._run_interactor, args=(workflow_input,), daemon=True
            )
            workflow_thread.start()
            self._workflow_instance_id = self.workflow_instance_id
        history.append({"role": "user", "content": message})
        result = {
            "agent_id": self._workflow_instance_id,
            "messages": [
                {"role": "user", "content": [{"data": message, "type": "text"}]}
            ],
            "kwargs": {},
        }
        self._redis_client().xadd(
            f"{self._workflow_instance_id}_input",
            {"payload": json.dumps(result, ensure_ascii=False)},
        )
        return history, gr.Textbox(value=None, interactive=False)

    def add_processor_message(self, history, message):
        """Show uploaded files in the chat and queue them for the processor.

        Starts the processor workflow if none is running, then publishes one
        ``image_url`` item per uploaded file (``resource_id`` is its index) to
        the ``image_process`` Redis stream. Only ``message["files"]`` is used.
        """
        if self._workflow_instance_id is None:
            self._workflow_instance_id = self._processor.start_workflow_with_input(
                workflow_input={}, task_to_domain=self._task_to_domain
            )
        image_items = []
        for idx, x in enumerate(message["files"]):
            history.append({"role": "user", "content": {"path": x}})
            image_items.append(
                {"type": "image_url", "resource_id": str(idx), "data": str(x)}
            )
        result = {"content": image_items}
        self._redis_client().xadd(
            "image_process", {"payload": json.dumps(result, ensure_ascii=False)}
        )
        return history, gr.MultimodalTextbox(value=None, interactive=False)

    def bot(self, history, state):
        """Stream the interactor's progress and answer into the chat history.

        Generator used as a streaming Gradio handler: yields ``history`` after
        every appended/updated message. It reads two Redis streams through the
        ``omappagent`` consumer group: ``<id>_running`` (progress lines rendered
        as ``<pre class="running-message">``) and ``<id>_output`` (answer text
        or images; INCOMPLETE text chunks are merged into one message). Each
        message is acknowledged after it is shown. The loop ends when an output
        message has ``interaction_type == 1`` or ``content_status`` END_ANSWER;
        the latter also clears the workflow id so the next question starts a
        new run. Without a selected video it yields ``history`` once and stops.
        """
        if self._session_state(state).get('current_video') is None:
            yield history
            return
        stream_name = f"{self._workflow_instance_id}_output"
        consumer_name = f"{self._workflow_instance_id}_agent"
        group_name = "omappagent"
        running_stream_name = f"{self._workflow_instance_id}_running"
        self._check_redis_stream_exist(stream_name, group_name)
        self._check_redis_stream_exist(running_stream_name, group_name)
        while True:
            # Progress updates.
            running_messages = self._get_redis_stream_message(
                group_name, consumer_name, running_stream_name
            )
            for _, message_list in running_messages:
                for message_id, message in message_list:
                    payload_data = self._get_message_payload(message)
                    if payload_data is None:
                        continue
                    # Escaped because the text is rendered as HTML.
                    progress = html.escape(str(payload_data.get("progress", "")))
                    text = html.escape(str(payload_data.get("message", "")))
                    formatted_message = (
                        f'<pre class="running-message">{progress}: {text}</pre>'
                    )
                    history.append({"role": "assistant", "content": formatted_message})
                    yield history
                    self._redis_client().xack(
                        running_stream_name, group_name, message_id
                    )
            # Answer chunks.
            messages = self._get_redis_stream_message(
                group_name, consumer_name, stream_name
            )
            finish_flag = False
            for _, message_list in messages:
                for message_id, message in message_list:
                    payload_data = self._get_message_payload(message)
                    if payload_data is None:
                        continue
                    content_status = payload_data.get("content_status")
                    message_item = payload_data["message"]
                    if message_item["type"] == MessageType.IMAGE_URL.value:
                        history.append(
                            {
                                "role": "assistant",
                                "content": {"path": message_item["content"]},
                            }
                        )
                    elif content_status == ContentStatus.INCOMPLETE.value:
                        self._incomplete_message += message_item["content"]
                        self._upsert_assistant_message(history, self._incomplete_message)
                    elif self._incomplete_message != "":
                        # Final chunk of a streamed answer: complete and reset.
                        self._incomplete_message += message_item["content"]
                        self._upsert_assistant_message(history, self._incomplete_message)
                        self._incomplete_message = ""
                    else:
                        history.append(
                            {"role": "assistant", "content": message_item["content"]}
                        )
                    yield history
                    self._redis_client().xack(stream_name, group_name, message_id)
                    if payload_data.get("interaction_type") == 1:
                        finish_flag = True
                    if content_status == ContentStatus.END_ANSWER.value:
                        self._workflow_instance_id = None
                        finish_flag = True
            if finish_flag:
                break
            sleep(0.01)

    def processor_bot(self, history: list):
        """Show progress for the upload workflow and wait for it to finish.

        Polls the workflow started by ``add_processor_message`` every 10 ms
        until its status is terminal, then appends "completed" and clears the
        current workflow id.
        """
        history.append({"role": "assistant", "content": "processing..."})
        yield history
        while True:
            status = self._processor.get_workflow(
                workflow_id=self._workflow_instance_id
            ).status
            if status in terminal_status:
                history.append({"role": "assistant", "content": "completed"})
                yield history
                self._workflow_instance_id = None
                break
            sleep(0.01)

    def _get_redis_stream_message(
        self, group_name: str, consumer_name: str, stream_name: str
    ):
        """Read at most one new (``>``) message from ``stream_name`` for this consumer.

        No ``block`` argument is passed to ``xreadgroup``. Field names and values
        are decoded as UTF-8 (assumes the client returns bytes, i.e.
        ``decode_responses`` is off -- NOTE(review): confirm).

        Returns:
            ``[(stream, [(message_id, {field: value}), ...]), ...]``
        """
        messages = self._redis_client().xreadgroup(
            group_name, consumer_name, {stream_name: ">"}, count=1
        )
        return [
            (
                stream,
                [
                    (
                        message_id,
                        {
                            k.decode("utf-8"): v.decode("utf-8")
                            for k, v in message.items()
                        },
                    )
                    for message_id, message in message_list
                ],
            )
            for stream, message_list in messages
        ]

    def _check_redis_stream_exist(self, stream_name: str, group_name: str):
        """Create ``group_name`` on ``stream_name`` (and the stream itself) if needed.

        Any error, typically "group already exists", is logged at debug level
        and ignored.
        """
        try:
            self._redis_client().xgroup_create(
                stream_name, group_name, id="0", mkstream=True
            )
        except Exception as e:
            logging.debug(f"Consumer group may already exist: {e}")

    def _get_message_payload(self, message: dict):
        """Parse the JSON ``payload`` field of a stream message.

        Returns the decoded object, or ``None`` (after logging an error) when
        the field is missing/empty or is not valid JSON.
        """
        logging.info(f"Received running message: {message}")
        payload = message.get("payload")
        if not payload:
            logging.error("Payload is empty")
            return None
        try:
            payload_data = json.loads(payload)
        except json.JSONDecodeError as e:
            logging.error(f"Payload is not a valid JSON: {e}")
            return None
        return payload_data