from typing import Dict, Optional, Tuple, List, Any
import re
import xml.etree.ElementTree as ET
from datetime import datetime
import json
import logging

# Setup logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Create console handler if needed
if not logger.handlers:
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    logger.addHandler(ch)


class ToolType:
    """String constants naming tools.

    ``ResponseFormatter._create_tool_element`` compares incoming tool names
    against some of these to pick a tool-specific XML layout.
    """

    DUCKDUCKGO = "duckduckgo_search"
    REDDIT_NEWS = "reddit_x_gnews_newswire_crunchbase"
    PUBMED = "pubmed_search"
    CENSUS = "get_census_data"
    HEATMAP = "heatmap_code"
    MERMAID = "mermaid_diagram"
    WISQARS = "wisqars"
    WONDER = "wonder"
    NCHS = "nchs"
    ONESTEP = "onestep"
    DQS = "dqs_nhis_adult_summary_health_statistics"


class ResponseFormatter:
    """Static helpers that render agent output as a terminal string and as XML."""

    @staticmethod
    def _as_text(value: Any) -> Optional[str]:
        """Return *value* in a form ElementTree can serialize as element text.

        ``None`` and ``str`` pass through unchanged (``None`` serializes as an
        empty element); anything else is converted with ``str()`` because
        ``ET.tostring`` raises ``TypeError`` on non-string text.
        """
        if value is None or isinstance(value, str):
            return value
        return str(value)

    @staticmethod
    def format_thought(
        thought: str,
        observation: str,
        citations: Optional[List[Dict]] = None,
        metadata: Optional[Dict] = None,
        tool_outputs: Optional[List[Dict]] = None
    ) -> Tuple[str, str]:
        """Format agent thought for both terminal and XML output.

        Args:
            thought: The agent's thought text.
            observation: Observation text; the ``<observation>`` element is
                omitted when this is falsy.
            citations: Optional list of dicts; each becomes a ``<citation>``
                element whose attributes are the dict's items.
            metadata: Optional dict placed under ``"metadata"`` in the
                terminal JSON (an empty dict when omitted).
            tool_outputs: Optional list of dicts; each is read via its
                ``"type"`` and ``"content"`` keys for the XML, and the list is
                embedded as-is in the terminal JSON.

        Returns:
            ``(terminal_json, xml_string)``.
        """
        as_text = ResponseFormatter._as_text

        # Terminal format
        terminal_output = {
            "type": "agent_thought",
            "content": thought,
            "metadata": metadata or {}
        }
        if tool_outputs:
            terminal_output["tool_outputs"] = tool_outputs

        # XML format
        root = ET.Element("agent_response")
        thought_elem = ET.SubElement(root, "thought")
        thought_elem.text = as_text(thought)

        if observation:
            obs_elem = ET.SubElement(root, "observation")
            obs_elem.text = as_text(observation)

        if tool_outputs:
            tools_elem = ET.SubElement(root, "tool_outputs")
            for tool_output in tool_outputs:
                tool_elem = ET.SubElement(tools_elem, "tool_output")
                # Attribute values and text must be strings, otherwise
                # ET.tostring() fails for the whole response.
                tool_elem.set("type", str(tool_output.get("type") or ""))
                tool_elem.text = as_text(tool_output.get("content", ""))

        if citations:
            cites_elem = ET.SubElement(root, "citations")
            for citation in citations:
                cite_elem = ET.SubElement(cites_elem, "citation")
                for key, value in citation.items():
                    cite_elem.set(str(key), str(value))

        xml_output = ET.tostring(root, encoding='unicode')
        return json.dumps(terminal_output), xml_output

    @staticmethod
    def _create_tool_element(parent: ET.Element, tool_name: str, tool_data: Dict) -> ET.Element:
        """Create XML element for specific tool type with appropriate structure.

        Appends a ``<tool name="...">`` child to *parent*, fills it using the
        census, mermaid or health formatter when *tool_name* matches one of
        those tool types, and otherwise writes ``str(tool_data)`` (with
        markdown stripped) into a ``<content>`` child.

        Returns:
            The newly created ``<tool>`` element.
        """
        tool_elem = ET.SubElement(parent, "tool")
        tool_elem.set("name", tool_name)

        # Handle different tool types
        if tool_name == ToolType.CENSUS:
            ResponseFormatter._format_census_data(tool_elem, tool_data)
        elif tool_name == ToolType.MERMAID:
            ResponseFormatter._format_mermaid_data(tool_elem, tool_data)
        elif tool_name in (ToolType.WISQARS, ToolType.WONDER, ToolType.NCHS):
            ResponseFormatter._format_health_data(tool_elem, tool_data)
        else:
            # Generic tool output format
            content_elem = ET.SubElement(tool_elem, "content")
            content_elem.text = ResponseFormatter._clean_markdown(str(tool_data))

        return tool_elem

    @staticmethod
    def _format_census_data(tool_elem: ET.Element, data: Dict) -> None:
        """Format census data with specific structure.

        When *data* has an ``"llm_result"`` key, its value is parsed as JSON
        and each entry of the parsed ``"tracts"`` list becomes a
        ``<tract id="...">`` element with one child per remaining key
        (underscores removed from the tag name). If anything in that process
        fails, only a plain ``<content>`` fallback is written.
        """
        # Build the structured subtree detached from tool_elem so that a
        # failure half-way through cannot leave partial output next to the
        # fallback <content> element.
        tracts_elem = ET.Element("census_tracts")
        try:
            # Parse the llm_result to extract structured data
            if "llm_result" in data:
                result = json.loads(data["llm_result"])
                for tract_data in result.get("tracts", []):
                    tract_elem = ET.SubElement(tracts_elem, "tract")
                    tract_elem.set("id", str(tract_data.get("tract", "")))

                    # Add tract details
                    for key, value in tract_data.items():
                        if key != "tract":
                            detail_elem = ET.SubElement(tract_elem, key.replace("_", ""))
                            detail_elem.text = str(value)
        except Exception:
            # Fallback to simple format if parsing fails
            logger.debug("Falling back to plain census output", exc_info=True)
            content_elem = ET.SubElement(tool_elem, "content")
            content_elem.text = ResponseFormatter._clean_markdown(str(data))
        else:
            tool_elem.append(tracts_elem)

    @staticmethod
    def _format_mermaid_data(tool_elem: ET.Element, data: Dict) -> None:
        """Format mermaid diagram data with improved error handling.

        Reads the diagram text from ``data["content"]`` (falling back to
        ``data["mermaid_diagram"]``) when *data* is a dict, or uses *data*
        itself when it is a string, strips wrapper text, and stores the result
        in a ``<diagram>`` child. On error, logs it and writes a ``<content>``
        child containing a fixed error message instead.
        """
        try:
            # Extract content from data
            content = ""
            if isinstance(data, dict):
                content = data.get("content", data.get("mermaid_diagram", ""))
            elif isinstance(data, str):
                content = data

            # Clean any remaining markdown/JSON formatting:
            # drop ```mermaid / ``` fences,
            content = re.sub(r'```mermaid\s*|\s*```', '', content)
            # collapse a "tool response: ... {" prefix down to "{",
            content = re.sub(r'tool response:.*?{', '{', content)
            # and drop a trailing "." that follows a closing "}".
            content = re.sub(r'}\s*\.$', '}', content)
            cleaned = content.strip()
        except Exception as e:
            logger.error("Error formatting mermaid data: %s", e)
            content_elem = ET.SubElement(tool_elem, "content")
            content_elem.text = "Error formatting diagram"
        else:
            # Only attach <diagram> once cleaning succeeded, so the error path
            # does not also leave an empty <diagram/> behind.
            diagram_elem = ET.SubElement(tool_elem, "diagram")
            diagram_elem.text = cleaned

    @staticmethod
    def _format_health_data(tool_elem: ET.Element, data: Dict) -> None:
        """Format health-related data from WISQARS, WONDER, etc.

        For a dict *data*, each top-level key becomes a child element
        (underscores removed from the tag name). Dict values are expanded one
        level into sub-elements; other values are stored as text. Non-dict
        *data* adds nothing. If conversion fails, only a plain ``<content>``
        fallback is written.
        """
        if not isinstance(data, dict):
            return

        # Stage children on a detached element; attach only on full success.
        staged = ET.Element(tool_elem.tag)
        try:
            for key, value in data.items():
                category_elem = ET.SubElement(staged, key.replace("_", ""))
                if isinstance(value, dict):
                    for sub_key, sub_value in value.items():
                        sub_elem = ET.SubElement(category_elem, sub_key.replace("_", ""))
                        sub_elem.text = str(sub_value)
                else:
                    category_elem.text = str(value)
        except Exception:
            logger.debug("Falling back to plain health-data output", exc_info=True)
            content_elem = ET.SubElement(tool_elem, "content")
            content_elem.text = ResponseFormatter._clean_markdown(str(data))
        else:
            tool_elem.extend(list(staged))

    @staticmethod
    def _extract_tool_outputs(observation: str) -> Dict[str, Any]:
        """Extract and clean tool outputs from observation.

        *observation* is parsed as a JSON object. Every string value that
        contains the substring ``"llm_result"`` is kept: parsed as JSON when
        possible, otherwise as the raw string. All other entries are dropped.

        Returns:
            Mapping of key to parsed (or raw) tool output; empty when
            *observation* is not a string, is not valid JSON, or does not
            decode to a JSON object.
        """
        tool_outputs: Dict[str, Any] = {}
        if not isinstance(observation, str):
            return tool_outputs

        try:
            data = json.loads(observation)
        except (ValueError, RecursionError):
            # Not JSON: nothing to extract (json.JSONDecodeError is a ValueError).
            logger.debug("Observation is not valid JSON; no tool outputs extracted")
            return tool_outputs

        if not isinstance(data, dict):
            return tool_outputs

        for key, value in data.items():
            if isinstance(value, str) and "llm_result" in value:
                try:
                    tool_outputs[key] = json.loads(value)
                except (ValueError, RecursionError):
                    # Keep the raw string when it only mentions llm_result.
                    tool_outputs[key] = value
        return tool_outputs

    @staticmethod
    def format_message(message: str) -> Tuple[str, str]:
        """Format agent message for both terminal and XML output.

        Returns:
            ``(stripped_message, xml_string)`` where the XML wraps the
            stripped message in ``<agent_response><message>``.
        """
        # Terminal format
        terminal_output = message.strip()

        # XML format
        root = ET.Element("agent_response")
        msg_elem = ET.SubElement(root, "message")
        msg_elem.text = terminal_output
        xml_output = ET.tostring(root, encoding='unicode')

        return terminal_output, xml_output

    @staticmethod
    def format_error(error: str) -> Tuple[str, str]:
        """Format error message for both terminal and XML output.

        Args:
            error: The error text. Non-string values (for example an exception
                instance) are converted with ``str()`` for the XML so that
                serialization does not fail.

        Returns:
            ``("Error: <error>", xml_string)``.
        """
        # Terminal format
        terminal_output = f"Error: {error}"

        # XML format
        root = ET.Element("agent_response")
        error_elem = ET.SubElement(root, "error")
        error_elem.text = ResponseFormatter._as_text(error)
        xml_output = ET.tostring(root, encoding='unicode')

        return terminal_output, xml_output

    @staticmethod
    def _clean_markdown(text: str) -> str:
        """Clean markdown formatting from text.

        Removes fenced code blocks entirely, deletes the characters
        ``* _ ` #``, and collapses runs of three or more newlines to two.
        """
        text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)
        text = re.sub(r'[*_`#]', '', text)
        return re.sub(r'\n{3,}', '\n\n', text.strip())