from typing import Dict, Optional, Tuple, List, Any, Set
import re
import xml.etree.ElementTree as ET
from datetime import datetime
import json
import logging

# Setup logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Create console handler if needed (skipped when handlers are already
# attached to this logger, so repeated setup does not duplicate output)
if not logger.handlers:
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    ch.setFormatter(formatter)
    logger.addHandler(ch)


class StreamingFormatter:
    """Mutable state tracked while a response is being streamed.

    Holds the set of already-processed event ids, the accumulated tool
    outputs / citations / metadata, the id of the message currently being
    handled, and a text buffer for that message.
    """

    def __init__(self):
        self.processed_events = set()
        self.current_tool_outputs = []
        self.current_citations = []
        self.current_metadata = {}
        self.current_message_id = None
        self.current_message_buffer = ""

    def reset(self):
        """Reset the formatter state.

        The containers are cleared in place (not rebound), so any external
        reference to them observes the emptied container.
        """
        self.processed_events.clear()
        self.current_tool_outputs.clear()
        self.current_citations.clear()
        self.current_metadata.clear()
        self.current_message_id = None
        self.current_message_buffer = ""

    def append_to_buffer(self, text: str):
        """Append text to the current message buffer."""
        self.current_message_buffer += text

    def get_and_clear_buffer(self) -> str:
        """Return the current buffer content and leave the buffer empty."""
        content = self.current_message_buffer
        self.current_message_buffer = ""
        return content


class ToolType:
    """String identifiers for the known tool types."""

    DUCKDUCKGO = "duckduckgo_search"
    REDDIT_NEWS = "reddit_x_gnews_newswire_crunchbase"
    PUBMED = "pubmed_search"
    CENSUS = "get_census_data"
    HEATMAP = "heatmap_code"
    MERMAID = "mermaid_diagram"
    WISQARS = "wisqars"
    WONDER = "wonder"
    NCHS = "nchs"
    ONESTEP = "onestep"
    DQS = "dqs_nhis_adult_summary_health_statistics"


class ResponseFormatter:
    """Singleton that renders agent thoughts, messages and errors as XML.

    Every ``ResponseFormatter()`` call returns the same instance; the shared
    ``streaming_state`` and ``logger`` attributes are attached when that
    instance is first created.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ResponseFormatter, cls).__new__(cls)
            cls._instance.streaming_state = StreamingFormatter()
            cls._instance.logger = logger
        return cls._instance

    @staticmethod
    def _stringify(value: Any) -> str:
        """Coerce an arbitrary value to text that ElementTree can serialize.

        ElementTree raises ``TypeError`` at serialization time for element
        text or attribute values that are not strings, so anything placed
        into the tree goes through here first.

        Args:
            value: The value destined for element text or an attribute.

        Returns:
            ``""`` for ``None``, the value itself when it is already a
            ``str``, a JSON document for dicts/lists/tuples, and
            ``str(value)`` for everything else.
        """
        if value is None:
            return ""
        if isinstance(value, str):
            return value
        if isinstance(value, (dict, list, tuple)):
            try:
                # default=str is deliberate: nested non-JSON values (dates,
                # exceptions, ...) are rendered rather than aborting output.
                return json.dumps(value, ensure_ascii=False, default=str)
            except (TypeError, ValueError):
                # e.g. non-string-able keys or circular references
                return str(value)
        return str(value)

    def format_thought(
        self,
        thought: str,
        observation: str,
        tool_outputs: Optional[List[Dict]] = None,
        event_id: Optional[str] = None,
        message_id: Optional[str] = None
    ) -> Optional[Tuple[str, str]]:
        """Format thought and tool outputs as XML.

        Args:
            thought: Thought text; a ``<thought>`` element is emitted only
                when this is truthy.
            observation: Observation text; an ``<observation>`` element is
                emitted only when this is truthy.
            tool_outputs: Optional list of dicts. Each dict's ``"type"``
                becomes the ``type`` attribute and its ``"content"`` the
                text of a ``<tool_output>`` element. Non-string values are
                converted to text instead of failing serialization.
            event_id: Accepted for signature parity with ``format_error``;
                not used here.
            message_id: Accepted for signature parity with ``format_error``;
                not used here.

        Returns:
            A ``(thought, xml)`` tuple where ``thought`` is the argument as
            passed in and ``xml`` is the serialized ``<agent_response>``.
        """
        root = ET.Element("agent_response")

        if thought:
            thought_elem = ET.SubElement(root, "thought")
            thought_elem.text = self._stringify(thought)

        if observation:
            obs_elem = ET.SubElement(root, "observation")
            obs_elem.text = self._stringify(observation)

        if tool_outputs:
            tools_elem = ET.SubElement(root, "tool_outputs")
            for output in tool_outputs:
                tool_elem = ET.SubElement(tools_elem, "tool_output")
                # Coerce both values: a None attribute or a dict/number as
                # text would make ET.tostring raise TypeError.
                tool_elem.attrib["type"] = self._stringify(
                    output.get("type", "")
                )
                tool_elem.text = self._stringify(output.get("content", ""))

        xml_output = ET.tostring(root, encoding='unicode')
        return thought, xml_output

    def format_message(
        self,
        message: str,
        event_id: Optional[str] = None,
        message_id: Optional[str] = None
    ) -> Optional[Tuple[str, str]]:
        """Format message as XML for frontend.

        Args:
            message: Message text.
            event_id: Accepted for signature parity with ``format_error``;
                not used here.
            message_id: Accepted for signature parity with ``format_error``;
                not used here.

        Returns:
            ``None`` when ``message`` is falsy, otherwise a
            ``(message, xml)`` tuple with the message wrapped in
            ``<agent_response><message>``.
        """
        if not message:
            return None

        root = ET.Element("agent_response")
        msg_elem = ET.SubElement(root, "message")
        msg_elem.text = self._stringify(message)

        xml_output = ET.tostring(root, encoding='unicode')
        return message, xml_output

    def format_error(
        self,
        error: str,
        event_id: Optional[str] = None,
        message_id: Optional[str] = None
    ) -> Optional[Tuple[str, str]]:
        """Format error message for both terminal and XML output.

        Args:
            error: Error text. Non-string values (such as an exception
                object) are rendered with their string form in both outputs.
            event_id: When given, an id already recorded in the streaming
                state makes this call return ``None``; otherwise the id is
                recorded once the error has been formatted.
            message_id: Compared with the streaming state's current message
                id; on mismatch the streaming state is reset and this id
                becomes current. Note the default ``None`` also counts as a
                mismatch when a message id is currently set.

        Returns:
            ``None`` for an already-processed event or a falsy ``error``,
            otherwise a ``(terminal_output, xml_output)`` tuple.
        """
        state = self.streaming_state

        # Skip if already processed
        if event_id and event_id in state.processed_events:
            return None

        # Handle message state
        if message_id != state.current_message_id:
            state.reset()
            state.current_message_id = message_id

        # Skip empty errors (checked after the state handling above, so an
        # empty error for a new message id still switches the state over)
        if not error:
            return None

        # Terminal format
        terminal_output = f"Error: {error}"

        # XML format
        root = ET.Element("agent_response")
        error_elem = ET.SubElement(root, "error")
        error_elem.text = self._stringify(error)
        xml_output = ET.tostring(root, encoding='unicode')

        # Track processed event
        if event_id:
            state.processed_events.add(event_id)

        return terminal_output, xml_output

    @staticmethod
    def _clean_markdown(text: str) -> str:
        """Clean markdown formatting from text.

        Removes triple-backtick fenced spans together with their content,
        deletes every ``*``, ``_``, backtick and ``#`` character, strips
        surrounding whitespace, and collapses runs of three or more
        newlines down to two.
        """
        text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)
        text = re.sub(r'[*_`#]', '', text)
        return re.sub(r'\n{3,}', '\n\n', text.strip())