# cc-api / response_formatter.py
# Repository page residue kept for provenance (not part of the module):
#   Severian's picture
#   Upload 81 files
#   995af0f verified
#   raw / history / blame
#   7.98 kB
from typing import Dict, Optional, Tuple, List, Any
import re
import xml.etree.ElementTree as ET
from datetime import datetime
import json
class ToolType:
    """String identifiers of the tools whose output may appear in an observation.

    The values are plain strings so they can be compared directly with the
    tool names found as keys in decoded observation JSON.
    """

    # Tools with no dedicated XML layout in ResponseFormatter (generic <content>).
    DUCKDUCKGO = "duckduckgo_search"
    REDDIT_NEWS = "reddit_x_gnews_newswire_crunchbase"
    PUBMED = "pubmed_search"
    HEATMAP = "heatmap_code"
    ONESTEP = "onestep"
    DQS = "dqs_nhis_adult_summary_health_statistics"

    # Tools that ResponseFormatter._create_tool_element formats specially.
    CENSUS = "get_census_data"
    MERMAID = "mermaid_diagram"
    WISQARS = "wisqars"
    WONDER = "wonder"
    NCHS = "nchs"
class ResponseFormatter:
    """Render agent output as a ``(terminal_text, xml_string)`` pair.

    Every method is static; the class only serves as a namespace. The XML
    is built with ``xml.etree.ElementTree`` and always has an
    ``<agent_response>`` root.
    """

    # Anything outside word characters, "." and "-" is replaced when a data
    # key is turned into an XML element name. ``\w`` is Unicode-aware, so
    # non-ASCII letters survive.
    _INVALID_TAG_CHARS = re.compile(r"[^\w.\-]")

    @staticmethod
    def format_thought(
        thought: str,
        observation: Optional[str] = None,
        citations: Optional[List[Dict]] = None,
        metadata: Optional[Dict] = None,
    ) -> Tuple[str, str]:
        """Format an agent thought and optional observation.

        Args:
            thought: The agent's reasoning text; surrounding whitespace is
                stripped.
            observation: Raw observation text. If it decodes as a JSON object
                containing tool results, those are rendered under
                ``<observation><tools>``; otherwise the markdown-cleaned text
                becomes the ``<observation>`` body.
            citations: Optional list of mappings; each becomes a
                ``<citation>`` whose children are named after the keys.
            metadata: Optional mapping rendered as children of ``<metadata>``.

        Returns:
            ``(terminal_output, xml_output)``.
        """
        thought_text = thought.strip()

        # Terminal format
        terminal_output = thought_text
        cleaned_obs = (
            ResponseFormatter._clean_markdown(observation) if observation else ""
        )
        if cleaned_obs:
            terminal_output += f"\n\nObservation:\n{cleaned_obs}"

        # XML format
        root = ET.Element("agent_response")
        ET.SubElement(root, "thought").text = thought_text

        if observation:
            obs_elem = ET.SubElement(root, "observation")
            tool_outputs = ResponseFormatter._extract_tool_outputs(observation)
            if tool_outputs:
                tools_elem = ET.SubElement(obs_elem, "tools")
                for tool_name, tool_data in tool_outputs.items():
                    ResponseFormatter._create_tool_element(
                        tools_elem, tool_name, tool_data
                    )
            elif cleaned_obs:
                # No structured tool output: keep the observation text so the
                # XML carries the same information as the terminal output.
                obs_elem.text = cleaned_obs

        if citations:
            citations_elem = ET.SubElement(root, "citations")
            for citation in citations:
                cite_elem = ET.SubElement(citations_elem, "citation")
                for key, value in citation.items():
                    detail = ET.SubElement(
                        cite_elem, ResponseFormatter._safe_tag(key)
                    )
                    detail.text = str(value)

        if metadata:
            metadata_elem = ET.SubElement(root, "metadata")
            for key, value in metadata.items():
                detail = ET.SubElement(
                    metadata_elem, ResponseFormatter._safe_tag(key)
                )
                detail.text = str(value)

        xml_output = ET.tostring(root, encoding="unicode")
        return terminal_output, xml_output

    @staticmethod
    def _safe_tag(name: Any, strip_underscores: bool = False) -> str:
        """Turn an arbitrary data key into a usable XML element name.

        Keys that are already valid names are returned unchanged (apart from
        the optional underscore stripping), so existing output is preserved.

        Args:
            name: The key; converted with ``str()``.
            strip_underscores: Remove ``_`` first, matching the naming used
                for census and health-data elements.

        Returns:
            The key with disallowed characters replaced by ``_`` and a
            leading ``_`` added when it would not start with a letter or
            underscore (including the empty string).
        """
        tag = str(name)
        if strip_underscores:
            tag = tag.replace("_", "")
        tag = ResponseFormatter._INVALID_TAG_CHARS.sub("_", tag)
        if not tag or not (tag[0].isalpha() or tag[0] == "_"):
            tag = f"_{tag}"
        return tag

    @staticmethod
    def _add_fallback_content(tool_elem: ET.Element, data: Any) -> None:
        """Append a ``<content>`` child holding the markdown-cleaned ``str(data)``."""
        content_elem = ET.SubElement(tool_elem, "content")
        content_elem.text = ResponseFormatter._clean_markdown(str(data))

    @staticmethod
    def _create_tool_element(
        parent: ET.Element, tool_name: str, tool_data: Any
    ) -> ET.Element:
        """Append a ``<tool name=...>`` element to ``parent`` and fill it.

        Census, mermaid and WISQARS/WONDER/NCHS results get dedicated
        layouts; every other tool gets a generic ``<content>`` child.

        Returns:
            The newly created ``<tool>`` element.
        """
        tool_elem = ET.SubElement(parent, "tool")
        tool_elem.set("name", tool_name)

        if tool_name == ToolType.CENSUS:
            ResponseFormatter._format_census_data(tool_elem, tool_data)
        elif tool_name == ToolType.MERMAID:
            ResponseFormatter._format_mermaid_data(tool_elem, tool_data)
        elif tool_name in (ToolType.WISQARS, ToolType.WONDER, ToolType.NCHS):
            ResponseFormatter._format_health_data(tool_elem, tool_data)
        else:
            ResponseFormatter._add_fallback_content(tool_elem, tool_data)
        return tool_elem

    @staticmethod
    def _format_census_data(tool_elem: ET.Element, data: Any) -> None:
        """Append census tract data to ``tool_elem``.

        Expects ``data`` to be a dict whose optional ``"llm_result"`` entry is
        a JSON string (or an already-decoded dict) of the form
        ``{"tracts": [{"tract": ..., <field>: ...}, ...]}``. Each tract
        becomes ``<tract id=...>`` with one child per remaining field; field
        names have underscores removed.

        If the data does not have that shape, nothing structured is emitted
        and a single ``<content>`` fallback is appended instead.
        """
        tracts_elem = ResponseFormatter._build_census_tracts(data)
        if tracts_elem is None:
            ResponseFormatter._add_fallback_content(tool_elem, data)
        else:
            tool_elem.append(tracts_elem)

    @staticmethod
    def _build_census_tracts(data: Any) -> Optional[ET.Element]:
        """Build a detached ``<census_tracts>`` element, or ``None`` if ``data`` is malformed.

        Building off-tree means a failure part-way through never leaves a
        half-filled element attached to the output.
        """
        if not isinstance(data, dict):
            return None

        tracts_elem = ET.Element("census_tracts")
        if "llm_result" not in data:
            return tracts_elem

        result = data["llm_result"]
        if isinstance(result, (str, bytes, bytearray)):
            try:
                result = json.loads(result)
            except (ValueError, RecursionError):
                return None
        if not isinstance(result, dict):
            return None

        tracts = result.get("tracts", [])
        if not isinstance(tracts, (list, tuple)):
            return None
        if not all(isinstance(tract, dict) for tract in tracts):
            return None

        for tract in tracts:
            tract_elem = ET.SubElement(tracts_elem, "tract")
            tract_elem.set("id", str(tract.get("tract", "")))
            for key, value in tract.items():
                if key == "tract":
                    continue  # already emitted as the id attribute
                tag = ResponseFormatter._safe_tag(key, strip_underscores=True)
                ET.SubElement(tract_elem, tag).text = str(value)
        return tracts_elem

    @staticmethod
    def _format_mermaid_data(tool_elem: ET.Element, data: Any) -> None:
        """Append a ``<diagram>`` child holding the mermaid source.

        The surrounding mermaid code fence is removed from
        ``data["mermaid_diagram"]``. A dict without that key yields an empty
        ``<diagram>``; non-dict data or a non-string diagram falls back to a
        ``<content>`` child.
        """
        if not isinstance(data, dict):
            ResponseFormatter._add_fallback_content(tool_elem, data)
            return
        if "mermaid_diagram" not in data:
            ET.SubElement(tool_elem, "diagram")
            return

        raw_code = data["mermaid_diagram"]
        if not isinstance(raw_code, str):
            ResponseFormatter._add_fallback_content(tool_elem, data)
            return

        diagram_elem = ET.SubElement(tool_elem, "diagram")
        mermaid_code = re.sub(r'```mermaid\s*|\s*```', '', raw_code)
        diagram_elem.text = mermaid_code.strip()

    @staticmethod
    def _format_health_data(tool_elem: ET.Element, data: Any) -> None:
        """Append health-tool data (WISQARS, WONDER, NCHS) to ``tool_elem``.

        Each top-level key becomes a child element (underscores removed from
        the name). A dict value is expanded one level into grandchildren;
        any other value is stringified. Non-dict data falls back to a
        ``<content>`` child rather than being dropped.
        """
        if not isinstance(data, dict):
            ResponseFormatter._add_fallback_content(tool_elem, data)
            return

        for key, value in data.items():
            category_elem = ET.SubElement(
                tool_elem, ResponseFormatter._safe_tag(key, strip_underscores=True)
            )
            if isinstance(value, dict):
                for sub_key, sub_value in value.items():
                    sub_tag = ResponseFormatter._safe_tag(
                        sub_key, strip_underscores=True
                    )
                    ET.SubElement(category_elem, sub_tag).text = str(sub_value)
            else:
                category_elem.text = str(value)

    @staticmethod
    def _extract_tool_outputs(observation: str) -> Dict[str, Any]:
        """Pull tool results out of a JSON-encoded observation.

        The observation must decode to a JSON object. Only entries whose
        value is a string containing ``"llm_result"`` are kept; each such
        value is decoded as JSON when possible and kept as the raw string
        otherwise.

        Returns:
            Mapping of tool name to decoded result; empty when the
            observation is not a JSON object or holds no tool results.
        """
        if not isinstance(observation, str):
            return {}
        try:
            data = json.loads(observation)
        except (ValueError, RecursionError):
            # Plain-text observations are expected; they simply have no tools.
            return {}
        if not isinstance(data, dict):
            return {}

        tool_outputs: Dict[str, Any] = {}
        for key, value in data.items():
            if isinstance(value, str) and "llm_result" in value:
                try:
                    tool_outputs[key] = json.loads(value)
                except (ValueError, RecursionError):
                    tool_outputs[key] = value
        return tool_outputs

    @staticmethod
    def format_message(message: str) -> Tuple[str, str]:
        """Format a plain agent message.

        Returns:
            The stripped message and an ``<agent_response><message>`` XML
            document containing the same text.
        """
        terminal_output = message.strip()

        root = ET.Element("agent_response")
        ET.SubElement(root, "message").text = terminal_output
        xml_output = ET.tostring(root, encoding="unicode")
        return terminal_output, xml_output

    @staticmethod
    def format_error(error: str) -> Tuple[str, str]:
        """Format an error.

        Args:
            error: The error text. Exception objects are accepted as well and
                are converted with ``str()`` so the XML can be serialised.

        Returns:
            ``("Error: <error>", xml)`` where ``xml`` is an
            ``<agent_response><error>`` document.
        """
        terminal_output = f"Error: {error}"

        root = ET.Element("agent_response")
        error_elem = ET.SubElement(root, "error")
        # ElementTree can only serialise str text; None keeps the element empty.
        error_elem.text = None if error is None else str(error)
        xml_output = ET.tostring(root, encoding="unicode")
        return terminal_output, xml_output

    @staticmethod
    def _clean_markdown(text: str) -> str:
        """Strip markdown formatting from ``text``.

        Fenced code blocks are removed together with their contents, the
        characters ``*``, ``_``, backtick and ``#`` are deleted wherever they
        occur (including inside ordinary words), surrounding whitespace is
        trimmed, and runs of three or more newlines collapse to two.
        """
        text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)
        text = re.sub(r'[*_`#]', '', text)
        return re.sub(r'\n{3,}', '\n\n', text.strip())