Spaces:
Sleeping
Sleeping
from typing import Dict, Optional, Tuple, List, Any | |
import re | |
import xml.etree.ElementTree as ET | |
from datetime import datetime | |
import json | |
import logging | |
# Module-wide logger; records at INFO level and above are processed.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Attach a console handler only when the logger has none yet, so running this
# setup again does not stack a second handler on the same logger.
if not logger.handlers:
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    ch = logging.StreamHandler()
    ch.setFormatter(formatter)
    ch.setLevel(logging.INFO)
    logger.addHandler(ch)
class ToolType:
    """Namespace of string constants naming the tools referenced in this module."""

    # Each value is the plain string a tool name is compared against.
    DUCKDUCKGO = "duckduckgo_search"
    REDDIT_NEWS = "reddit_x_gnews_newswire_crunchbase"
    PUBMED = "pubmed_search"
    CENSUS = "get_census_data"
    HEATMAP = "heatmap_code"
    MERMAID = "mermaid_diagram"
    WISQARS = "wisqars"
    WONDER = "wonder"
    NCHS = "nchs"
    ONESTEP = "onestep"
    DQS = "dqs_nhis_adult_summary_health_statistics"
class ResponseFormatter:
    """Render agent thoughts, messages, errors and tool data as text and XML.

    Every method is a static method: none of them reads instance or class
    state, so each can be called on the class or on an instance.
    """

    @staticmethod
    def _to_text(value: Any) -> str:
        """Return *value* as a string that ElementTree can serialize.

        ``None`` becomes ``""`` (both serialize as an empty element), strings
        are returned unchanged, and anything else goes through ``str()``.
        """
        if value is None:
            return ""
        return value if isinstance(value, str) else str(value)

    @staticmethod
    def format_thought(
        thought: str,
        observation: str,
        citations: Optional[List[Dict]] = None,
        metadata: Optional[Dict] = None,
        tool_outputs: Optional[List[Dict]] = None,
    ) -> Tuple[str, str]:
        """Format an agent thought for both terminal and XML output.

        Args:
            thought: Text of the thought; stored under ``"content"`` in the
                terminal payload and as ``<thought>`` in the XML.
            observation: Optional observation text; emitted as
                ``<observation>`` only when truthy.
            citations: Optional list of dicts; each becomes a ``<citation>``
                element whose attributes are the dict's items.
            metadata: Optional dict copied into the terminal payload
                (``{}`` when omitted).
            tool_outputs: Optional list of dicts with ``"type"`` and
                ``"content"`` keys; included in both outputs when non-empty.

        Returns:
            ``(terminal_json, xml)`` where ``terminal_json`` is a JSON string
            and ``xml`` is the serialized ``<agent_response>`` element.

        Raises:
            TypeError: If ``thought``, ``metadata`` or ``tool_outputs`` hold
                values that ``json.dumps`` cannot serialize.
        """
        # Terminal format
        terminal_output = {
            "type": "agent_thought",
            "content": thought,
            "metadata": metadata or {},
        }
        if tool_outputs:
            terminal_output["tool_outputs"] = tool_outputs

        # XML format. ElementTree only serializes str text/attributes, so
        # every value is coerced before it is placed in the tree.
        to_text = ResponseFormatter._to_text
        root = ET.Element("agent_response")
        ET.SubElement(root, "thought").text = to_text(thought)
        if observation:
            ET.SubElement(root, "observation").text = to_text(observation)
        if tool_outputs:
            tools_elem = ET.SubElement(root, "tool_outputs")
            for tool_output in tool_outputs:
                tool_elem = ET.SubElement(tools_elem, "tool_output")
                tool_elem.set("type", to_text(tool_output.get("type", "")))
                tool_elem.text = to_text(tool_output.get("content", ""))
        if citations:
            cites_elem = ET.SubElement(root, "citations")
            for citation in citations:
                cite_elem = ET.SubElement(cites_elem, "citation")
                for key, value in citation.items():
                    cite_elem.set(str(key), str(value))

        xml_output = ET.tostring(root, encoding='unicode')
        return json.dumps(terminal_output), xml_output

    @staticmethod
    def _create_tool_element(parent: ET.Element, tool_name: str, tool_data: Dict) -> ET.Element:
        """Append a ``<tool name=...>`` element to *parent* and fill it.

        Census, mermaid and WISQARS/WONDER/NCHS tool names are delegated to
        their dedicated formatters; any other name gets a single
        ``<content>`` child holding the markdown-cleaned ``str(tool_data)``.

        Returns:
            The newly created ``<tool>`` element.
        """
        tool_elem = ET.SubElement(parent, "tool")
        tool_elem.set("name", tool_name)

        # Handle different tool types
        if tool_name == ToolType.CENSUS:
            ResponseFormatter._format_census_data(tool_elem, tool_data)
        elif tool_name == ToolType.MERMAID:
            ResponseFormatter._format_mermaid_data(tool_elem, tool_data)
        elif tool_name in (ToolType.WISQARS, ToolType.WONDER, ToolType.NCHS):
            ResponseFormatter._format_health_data(tool_elem, tool_data)
        else:
            # Generic tool output format
            content_elem = ET.SubElement(tool_elem, "content")
            content_elem.text = ResponseFormatter._clean_markdown(str(tool_data))
        return tool_elem

    @staticmethod
    def _format_census_data(tool_elem: ET.Element, data: Dict) -> None:
        """Add census tract data to *tool_elem*.

        When ``data["llm_result"]`` is present it is parsed as JSON and each
        entry of its ``"tracts"`` list becomes a ``<tract id=...>`` element
        with one child per remaining key (underscores removed from the tag).
        A ``<census_tracts>`` element is added even when it ends up empty.

        If anything about the input prevents that (bad JSON, unexpected
        types), the element instead receives a single ``<content>`` child
        with the markdown-cleaned ``str(data)``.
        """
        # Build detached so a failure half-way through cannot leave a partial
        # <census_tracts> next to the fallback <content>.
        tracts_elem = ET.Element("census_tracts")
        try:
            if "llm_result" in data:
                result = json.loads(data["llm_result"])
                for tract_data in result.get("tracts", []):
                    tract_elem = ET.SubElement(tracts_elem, "tract")
                    tract_elem.set("id", str(tract_data.get("tract", "")))
                    # Add tract details
                    for key, value in tract_data.items():
                        if key != "tract":
                            detail_elem = ET.SubElement(tract_elem, key.replace("_", ""))
                            detail_elem.text = str(value)
        except Exception:
            # Best-effort fallback to the simple format; Exception (not a bare
            # except) so KeyboardInterrupt/SystemExit still propagate.
            content_elem = ET.SubElement(tool_elem, "content")
            content_elem.text = ResponseFormatter._clean_markdown(str(data))
        else:
            tool_elem.append(tracts_elem)

    @staticmethod
    def _format_mermaid_data(tool_elem: ET.Element, data: Any) -> None:
        """Add a ``<diagram>`` child holding the cleaned mermaid source.

        *data* may be a dict (``"content"`` is used, else
        ``"mermaid_diagram"``) or a plain string; any other type yields an
        empty diagram. Markdown fences, a leading ``tool response:`` prefix
        and a trailing ``}.`` are stripped from the text.

        On failure the error is logged and a ``<content>`` child with the
        text ``"Error formatting diagram"`` is added instead.
        """
        try:
            # Extract content from data
            content = ""
            if isinstance(data, dict):
                content = data.get("content", data.get("mermaid_diagram", ""))
            elif isinstance(data, str):
                content = data
            # Clean any remaining markdown/JSON formatting
            content = re.sub(r'```mermaid\s*|\s*```', '', content)
            content = re.sub(r'tool response:.*?{', '{', content)
            content = re.sub(r'}\s*\.$', '}', content)
            cleaned = content.strip()
        except Exception as e:
            logger.error(f"Error formatting mermaid data: {e}")
            content_elem = ET.SubElement(tool_elem, "content")
            content_elem.text = "Error formatting diagram"
        else:
            # Only attach <diagram> once cleaning succeeded, so the error path
            # does not also leave an empty <diagram> behind.
            ET.SubElement(tool_elem, "diagram").text = cleaned

    @staticmethod
    def _format_health_data(tool_elem: ET.Element, data: Dict) -> None:
        """Add health data (WISQARS, WONDER, NCHS) to *tool_elem*.

        Each top-level key becomes a child element (underscores removed from
        the tag). Dict values are expanded one level into sub-elements; any
        other value is stored as the element text. Non-dict *data* adds
        nothing. If the dict cannot be converted (e.g. a non-string key),
        a single ``<content>`` child with the markdown-cleaned ``str(data)``
        is added instead.
        """
        if not isinstance(data, dict):
            return

        # Stage children first so a failure cannot leave a partial result
        # alongside the fallback <content>.
        staged = []
        try:
            for key, value in data.items():
                category_elem = ET.Element(key.replace("_", ""))
                if isinstance(value, dict):
                    for sub_key, sub_value in value.items():
                        sub_elem = ET.SubElement(category_elem, sub_key.replace("_", ""))
                        sub_elem.text = str(sub_value)
                else:
                    category_elem.text = str(value)
                staged.append(category_elem)
        except Exception:
            content_elem = ET.SubElement(tool_elem, "content")
            content_elem.text = ResponseFormatter._clean_markdown(str(data))
        else:
            tool_elem.extend(staged)

    @staticmethod
    def _extract_tool_outputs(observation: str) -> Dict[str, Any]:
        """Extract tool outputs from a JSON-encoded observation.

        *observation* must be a JSON object string. For every string value
        containing ``"llm_result"``, the value is parsed as JSON and stored
        under the same key; if it is not valid JSON the raw string is stored.

        Returns:
            The extracted mapping, or ``{}`` when *observation* is not a
            string, not valid JSON, or not a JSON object.
        """
        tool_outputs: Dict[str, Any] = {}
        if not isinstance(observation, str):
            return tool_outputs
        try:
            data = json.loads(observation)
        except (ValueError, RecursionError):
            # Not JSON (JSONDecodeError is a ValueError) or nested too deeply.
            return tool_outputs
        if not isinstance(data, dict):
            return tool_outputs

        for key, value in data.items():
            if isinstance(value, str) and "llm_result" in value:
                try:
                    tool_outputs[key] = json.loads(value)
                except (ValueError, RecursionError):
                    tool_outputs[key] = value
        return tool_outputs

    @staticmethod
    def format_message(message: str) -> Tuple[str, str]:
        """Format an agent message for both terminal and XML output.

        Returns:
            ``(text, xml)`` where ``text`` is *message* stripped of
            surrounding whitespace and ``xml`` wraps the same text in
            ``<agent_response><message>``.
        """
        stripped = message.strip()
        root = ET.Element("agent_response")
        ET.SubElement(root, "message").text = stripped
        return stripped, ET.tostring(root, encoding='unicode')

    @staticmethod
    def format_error(error: str) -> Tuple[str, str]:
        """Format an error for both terminal and XML output.

        *error* is normally a string, but an exception object is accepted
        too: it is converted with ``str()`` before being placed in the XML.

        Returns:
            ``(text, xml)`` where ``text`` is ``"Error: <error>"`` and
            ``xml`` wraps the error text in ``<agent_response><error>``.
        """
        terminal_output = f"Error: {error}"
        root = ET.Element("agent_response")
        # ElementTree refuses to serialize non-str text, so coerce here.
        ET.SubElement(root, "error").text = ResponseFormatter._to_text(error)
        return terminal_output, ET.tostring(root, encoding='unicode')

    @staticmethod
    def _clean_markdown(text: str) -> str:
        """Strip markdown formatting from *text*.

        Removes fenced code blocks entirely, deletes every ``*``, ``_``,
        backtick and ``#`` character, trims surrounding whitespace and
        collapses runs of three or more newlines to two.
        """
        text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)
        text = re.sub(r'[*_`#]', '', text)
        return re.sub(r'\n{3,}', '\n\n', text.strip())