Spaces:
Sleeping
Sleeping
from typing import Dict, Optional, Tuple, List, Any, Set | |
import re | |
import xml.etree.ElementTree as ET | |
from datetime import datetime | |
import json | |
import logging | |
# Module-level logger for this file, emitting INFO and above.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Add a console handler only when the logger has none yet, so the handler
# is not attached a second time if this setup code runs again.
if not logger.handlers:
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    logger.addHandler(ch)
class StreamingFormatter:
    """Mutable state kept while a single message is being streamed.

    Holds the set of event ids that were already handled, containers for
    tool outputs, citations and metadata, the id of the message currently
    being processed, and a text buffer for message content.
    """

    def __init__(self):
        # Ids of events that have already been handled (used for de-duplication).
        self.processed_events: Set[str] = set()
        # NOTE(review): the element shapes of the three containers below are
        # not established by this file - confirm against the callers.
        self.current_tool_outputs: List[Any] = []
        self.current_citations: List[Any] = []
        self.current_metadata: Dict[str, Any] = {}
        # Id of the message the state belongs to; None until one is set.
        self.current_message_id: Optional[str] = None
        # Text accumulated via append_to_buffer().
        self.current_message_buffer: str = ""

    def reset(self) -> None:
        """Reset the formatter state.

        The containers are cleared in place (not replaced), so references
        held elsewhere keep pointing at the same, now empty, objects.
        """
        self.processed_events.clear()
        self.current_tool_outputs.clear()
        self.current_citations.clear()
        self.current_metadata.clear()
        self.current_message_id = None
        self.current_message_buffer = ""

    def append_to_buffer(self, text: str) -> None:
        """Append text to the current message buffer."""
        self.current_message_buffer += text

    def get_and_clear_buffer(self) -> str:
        """Return the current buffer content and leave the buffer empty."""
        content = self.current_message_buffer
        self.current_message_buffer = ""
        return content
class ToolType:
    """Namespace of tool-name string constants.

    NOTE(review): presumably these match the ``type`` values carried by
    tool outputs - confirm against the callers.
    """

    DUCKDUCKGO = "duckduckgo_search"
    REDDIT_NEWS = "reddit_x_gnews_newswire_crunchbase"
    PUBMED = "pubmed_search"
    CENSUS = "get_census_data"
    HEATMAP = "heatmap_code"
    MERMAID = "mermaid_diagram"
    WISQARS = "wisqars"
    WONDER = "wonder"
    NCHS = "nchs"
    ONESTEP = "onestep"
    DQS = "dqs_nhis_adult_summary_health_statistics"
class ResponseFormatter:
    """Singleton that renders agent output as ``<agent_response>`` XML.

    Each ``format_*`` method returns a ``(text, xml)`` tuple, where ``xml``
    is a serialized ``<agent_response>`` element, or ``None`` when there is
    nothing to emit.
    """

    # The single shared instance, created lazily by __new__.
    _instance = None

    def __new__(cls):
        # Build the shared instance on first use; every later call returns
        # the same object, including its streaming state.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.streaming_state = StreamingFormatter()
            cls._instance.logger = logger
        return cls._instance

    @staticmethod
    def _as_text(value: Any) -> str:
        """Return *value* as a string that ElementTree can serialize.

        ``ET.tostring`` raises ``TypeError`` for non-string text or
        attribute values, so ``None`` becomes ``""`` and any other
        non-string value goes through ``str()``.
        """
        if value is None:
            return ""
        return value if isinstance(value, str) else str(value)

    def format_thought(
        self,
        thought: str,
        observation: str,
        tool_outputs: Optional[List[Dict]] = None,
        event_id: Optional[str] = None,
        message_id: Optional[str] = None
    ) -> Optional[Tuple[str, str]]:
        """Format thought and tool outputs as XML.

        Args:
            thought: Thought text; the ``<thought>`` element is omitted
                when this is empty.
            observation: Observation text; the ``<observation>`` element is
                omitted when this is empty.
            tool_outputs: Optional list of dicts; the ``type`` and
                ``content`` keys of each are read (missing keys count as
                empty). Non-string values are converted to text.
            event_id: Accepted but currently unused by this method.
            message_id: Accepted but currently unused by this method.

        Returns:
            ``(thought, xml)`` where ``thought`` is returned as passed in.
        """
        root = ET.Element("agent_response")
        if thought:
            thought_elem = ET.SubElement(root, "thought")
            thought_elem.text = self._as_text(thought)
        if observation:
            obs_elem = ET.SubElement(root, "observation")
            obs_elem.text = self._as_text(observation)
        if tool_outputs:
            tools_elem = ET.SubElement(root, "tool_outputs")
            for output in tool_outputs:
                tool_elem = ET.SubElement(tools_elem, "tool_output")
                # Coerce to text: a non-string type/content would make
                # ET.tostring raise TypeError.
                tool_elem.attrib["type"] = self._as_text(output.get("type", ""))
                tool_elem.text = self._as_text(output.get("content", ""))
        xml_output = ET.tostring(root, encoding='unicode')
        return thought, xml_output

    def format_message(
        self,
        message: str,
        event_id: Optional[str] = None,
        message_id: Optional[str] = None
    ) -> Optional[Tuple[str, str]]:
        """Format message as XML for frontend.

        Args:
            message: Message text to wrap in a ``<message>`` element.
            event_id: Accepted but currently unused by this method.
            message_id: Accepted but currently unused by this method.

        Returns:
            ``(message, xml)``, or ``None`` when ``message`` is empty.
        """
        if not message:
            return None
        root = ET.Element("agent_response")
        msg_elem = ET.SubElement(root, "message")
        msg_elem.text = self._as_text(message)
        xml_output = ET.tostring(root, encoding='unicode')
        return message, xml_output

    def format_error(
        self,
        error: str,
        event_id: Optional[str] = None,
        message_id: Optional[str] = None
    ) -> Optional[Tuple[str, str]]:
        """Format error message for both terminal and XML output.

        Args:
            error: Error text (non-string values such as exception objects
                are converted to text for the XML).
            event_id: When given, an event id that was already processed
                makes the call return ``None``; a newly formatted error
                records its id.
            message_id: Compared with the current message id of the
                streaming state; a different value resets that state.

        Returns:
            ``("Error: <error>", xml)``, or ``None`` for an
            already-processed event or an empty error.
        """
        # Skip if already processed
        if event_id and event_id in self.streaming_state.processed_events:
            return None
        # Handle message state. This runs before the empty-error check, so
        # a call with an empty error still resets the state when the
        # message id differs.
        if message_id != self.streaming_state.current_message_id:
            self.streaming_state.reset()
            self.streaming_state.current_message_id = message_id
        # Skip empty errors
        if not error:
            return None
        # Terminal format
        terminal_output = f"Error: {error}"
        # XML format
        root = ET.Element("agent_response")
        error_elem = ET.SubElement(root, "error")
        error_elem.text = self._as_text(error)
        xml_output = ET.tostring(root, encoding='unicode')
        # Track processed event
        if event_id:
            self.streaming_state.processed_events.add(event_id)
        return terminal_output, xml_output
def _clean_markdown(text: str) -> str:
    """Clean markdown formatting from text.

    Removes fenced code blocks (``` ... ```) together with their contents,
    deletes every ``*``, ``_``, backtick and ``#`` character, trims
    surrounding whitespace and collapses runs of three or more newlines
    into a single blank line. Underscores inside ordinary words are
    removed as well.

    Args:
        text: The text to clean. ``None`` or an empty string yields ``""``.

    Returns:
        The cleaned text.
    """
    # Guard against None/empty input instead of failing on .strip() below.
    if not text:
        return ""
    # DOTALL lets a fenced block span several lines.
    text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)
    text = re.sub(r'[*_`#]', '', text)
    return re.sub(r'\n{3,}', '\n\n', text.strip())