cc-api / response_formatter.py
Severian's picture
Update response_formatter.py
2531f2f verified
raw
history blame
5.06 kB
from typing import Dict, Optional, Tuple, List, Any, Set
import re
import xml.etree.ElementTree as ET
from datetime import datetime
import json
import logging
# Module logger at INFO level. A console handler is attached only when the
# logger has none yet, so running this setup again (e.g. on a module reload)
# does not duplicate every log line.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
if not logger.handlers:
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
class StreamingFormatter:
    """Mutable state for the message currently being streamed.

    Holds the ids of events already handled, collected tool outputs,
    citations, free-form metadata, the id of the active message, and a text
    buffer that accumulates streamed chunks for that message.
    """

    def __init__(self):
        self.processed_events = set()   # event ids already handled
        self.current_tool_outputs = []
        self.current_citations = []
        self.current_metadata = {}
        self.current_message_id = None
        self.current_message_buffer = ""

    def reset(self):
        """Return every field to its initial empty value.

        The collections are cleared in place rather than rebound, so any
        outside reference to them observes the reset too.
        """
        for container in (
            self.processed_events,
            self.current_tool_outputs,
            self.current_citations,
            self.current_metadata,
        ):
            container.clear()
        self.current_message_id, self.current_message_buffer = None, ""

    def append_to_buffer(self, text: str):
        """Add *text* to the end of the current message buffer."""
        self.current_message_buffer = self.current_message_buffer + text

    def get_and_clear_buffer(self) -> str:
        """Return the buffered text and leave the buffer empty."""
        content, self.current_message_buffer = self.current_message_buffer, ""
        return content
class ToolType:
    """Named string constants identifying agent tools.

    A plain class (not an ``Enum``), so every attribute is an ordinary
    ``str`` that compares equal to the raw identifier.
    """

    DUCKDUCKGO: str = "duckduckgo_search"
    REDDIT_NEWS: str = "reddit_x_gnews_newswire_crunchbase"
    PUBMED: str = "pubmed_search"
    CENSUS: str = "get_census_data"
    HEATMAP: str = "heatmap_code"
    MERMAID: str = "mermaid_diagram"
    WISQARS: str = "wisqars"
    WONDER: str = "wonder"
    NCHS: str = "nchs"
    ONESTEP: str = "onestep"
    DQS: str = "dqs_nhis_adult_summary_health_statistics"
class ResponseFormatter:
    """Singleton that renders agent output as ``<agent_response>`` XML.

    Each ``format_*`` method returns a ``(plain_text, xml_string)`` pair, or
    ``None`` in the cases documented on the method.
    """

    _instance = None

    def __new__(cls):
        # Singleton: the first call builds the instance and attaches a fresh
        # StreamingFormatter plus the module logger; later calls return it.
        if cls._instance is None:
            cls._instance = super(ResponseFormatter, cls).__new__(cls)
            cls._instance.streaming_state = StreamingFormatter()
            cls._instance.logger = logger
        return cls._instance

    @staticmethod
    def _to_text(value: Any) -> str:
        """Coerce *value* to a string that ElementTree can serialize.

        ``ET.tostring`` raises ``TypeError`` for non-string text/attribute
        values, so: ``None`` becomes ``""``, strings pass through unchanged,
        and anything else is JSON-encoded when possible (keeping dict/list
        payloads machine-readable) with ``str()`` as the fallback.
        """
        if value is None:
            return ""
        if isinstance(value, str):
            return value
        try:
            return json.dumps(value, ensure_ascii=False)
        except (TypeError, ValueError):
            return str(value)

    def format_thought(
        self,
        thought: str,
        observation: str,
        tool_outputs: Optional[List[Dict]] = None,
        event_id: str = None,
        message_id: str = None
    ) -> Optional[Tuple[str, str]]:
        """Format thought and tool outputs as XML.

        Args:
            thought: Reasoning text; the ``<thought>`` element is omitted
                when it is falsy.
            observation: Observation text; ``<observation>`` is omitted when
                it is falsy.
            tool_outputs: Optional list of dicts; each becomes a
                ``<tool_output>`` whose ``type`` attribute comes from the
                ``"type"`` key and whose text comes from ``"content"``
                (both default to ``""``). Non-string values are serialized
                via ``_to_text`` instead of crashing serialization.
            event_id: Accepted for signature parity with ``format_error``;
                currently unused.
            message_id: Accepted for signature parity; currently unused.

        Returns:
            ``(thought, xml)``. A tuple is always returned, and ``thought``
            is passed back unchanged even when it is empty.
        """
        root = ET.Element("agent_response")
        if thought:
            ET.SubElement(root, "thought").text = self._to_text(thought)
        if observation:
            ET.SubElement(root, "observation").text = self._to_text(observation)
        if tool_outputs:
            tools_elem = ET.SubElement(root, "tool_outputs")
            for output in tool_outputs:
                tool_elem = ET.SubElement(tools_elem, "tool_output")
                tool_elem.set("type", self._to_text(output.get("type", "")))
                tool_elem.text = self._to_text(output.get("content", ""))
        xml_output = ET.tostring(root, encoding='unicode')
        return thought, xml_output

    def format_message(
        self,
        message: str,
        event_id: str = None,
        message_id: str = None
    ) -> Optional[Tuple[str, str]]:
        """Format message as XML for frontend.

        Args:
            message: Message text wrapped in a ``<message>`` element.
            event_id: Accepted for signature parity with ``format_error``;
                currently unused.
            message_id: Accepted for signature parity; currently unused.

        Returns:
            ``(message, xml)``, or ``None`` when *message* is falsy.
        """
        if not message:
            return None
        root = ET.Element("agent_response")
        ET.SubElement(root, "message").text = self._to_text(message)
        xml_output = ET.tostring(root, encoding='unicode')
        return message, xml_output

    def format_error(
        self,
        error: str,
        event_id: str = None,
        message_id: str = None
    ) -> Optional[Tuple[str, str]]:
        """Format error message for both terminal and XML output.

        Args:
            error: Error description. Non-string values (e.g. an exception
                object) are converted with ``_to_text``.
            event_id: When given, an event already recorded in
                ``streaming_state.processed_events`` is skipped, and a
                successfully formatted one is recorded.
            message_id: When it differs from the tracked message id, the
                streaming state is reset and this id becomes current.

        Returns:
            ``("Error: <text>", xml)``, or ``None`` when the event was
            already processed or *error* is falsy.
        """
        state = self.streaming_state

        # Skip events that were already emitted for the current message.
        if event_id and event_id in state.processed_events:
            return None

        # A different message id starts a new message: drop the old state.
        if message_id != state.current_message_id:
            state.reset()
            state.current_message_id = message_id

        if not error:
            return None

        error_text = self._to_text(error)
        terminal_output = f"Error: {error_text}"

        root = ET.Element("agent_response")
        ET.SubElement(root, "error").text = error_text
        xml_output = ET.tostring(root, encoding='unicode')

        if event_id:
            state.processed_events.add(event_id)
        return terminal_output, xml_output

    @staticmethod
    def _clean_markdown(text: str) -> str:
        """Clean markdown formatting from text.

        Removes fenced code blocks entirely, drops every ``*``, backtick and
        ``#`` character, and removes ``_`` only at word edges (emphasis
        markers), so identifiers such as ``get_census_data`` survive. The
        result is stripped and runs of 3+ newlines collapse to one blank
        line. Falsy input (including ``None``) yields ``""``.
        """
        if not text:
            return ""
        text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)
        # Intraword underscores are not markdown emphasis; only strip the
        # ones not flanked by word characters on both sides. Done before
        # removing '*' so boundaries are judged on the original characters.
        text = re.sub(r'(?<!\w)_+|_+(?!\w)', '', text)
        text = re.sub(r'[*`#]', '', text)
        return re.sub(r'\n{3,}', '\n\n', text.strip())