# Spaces: Running on Zero
from collections import deque | |
from datetime import datetime | |
import io | |
import logging | |
import sys | |
import threading | |
# Shared module state; everything stays None until setup_logger() runs:
#   logs               - deque of {"t": <ISO timestamp>, "m": <text>} entries
#   stdout_interceptor - the LogInterceptor installed as sys.stdout
#   stderr_interceptor - the LogInterceptor installed as sys.stderr
logs = stdout_interceptor = stderr_interceptor = None
class LogInterceptor(io.TextIOWrapper):
    """Text stream that mirrors everything written to it into the log buffers.

    It wraps the binary ``buffer`` of an existing text stream (e.g.
    ``sys.stdout``), so output still reaches the original destination. Each
    ``write`` is also recorded as ``{"t": <ISO timestamp>, "m": <data>}`` in
    two places: the module-level ``logs`` deque, and a per-stream list that
    is handed to registered callbacks on ``flush``.
    """

    def __init__(self, stream, *args, **kwargs):
        """Wrap ``stream.buffer``, inheriting the stream's text settings.

        Args:
            stream: Existing text stream exposing ``buffer``, ``encoding`` and
                ``line_buffering`` (for example ``sys.stdout``).
            *args, **kwargs: Extra ``io.TextIOWrapper`` arguments. Explicit
                keyword values override the settings copied from ``stream``.
        """
        # Copy the original stream's settings so text is encoded and buffered
        # exactly as before interception. Dropping ``errors`` would make an
        # intercepted stderr (normally "backslashreplace") raise
        # UnicodeEncodeError. Dropping ``write_through`` would hold
        # ``python -u`` output in the text layer.
        kwargs.setdefault("encoding", stream.encoding)
        kwargs.setdefault("errors", getattr(stream, "errors", None))
        kwargs.setdefault("line_buffering", stream.line_buffering)
        kwargs.setdefault("write_through", getattr(stream, "write_through", False))
        super().__init__(stream.buffer, *args, **kwargs)
        self._lock = threading.Lock()
        self._flush_callbacks = []
        self._logs_since_flush = []

    def write(self, data):
        """Record ``data`` in the log buffers, then write it to the stream.

        Returns:
            Whatever ``io.TextIOWrapper.write`` returns (the number of
            characters written), as the text-stream protocol requires.
        """
        entry = {"t": datetime.now().isoformat(), "m": data}
        with self._lock:
            self._logs_since_flush.append(entry)

            # A chunk starting with "\r" (e.g. a progress-bar redraw) replaces
            # the previous entry if that entry did not finish a line, so the
            # buffer is not flooded with intermediate progress messages.
            # The emptiness check avoids IndexError on the very first write,
            # and whenever the deque was created with maxlen=0.
            if (
                isinstance(data, str)
                and data.startswith("\r")
                and logs
                and not logs[-1]["m"].endswith("\n")
            ):
                logs.pop()
            logs.append(entry)
        return super().write(data)

    def flush(self):
        """Flush the stream, then give every callback the entries written since the last flush."""
        super().flush()
        # Swap the pending list out under the lock, so concurrent writes are
        # neither lost nor delivered twice. It is reset even when no callbacks
        # are registered, so it cannot grow without bound.
        with self._lock:
            pending, self._logs_since_flush = self._logs_since_flush, []
        # Every callback receives the same batch. Callbacks run outside the
        # (non-reentrant) lock, so one that writes to this stream cannot
        # deadlock.
        for cb in self._flush_callbacks:
            cb(pending)

    def on_flush(self, callback):
        """Register ``callback``; it is called with the list of entries recorded since the previous flush."""
        self._flush_callbacks.append(callback)
def get_logs():
    """Return the shared log deque, or ``None`` if ``setup_logger`` has not run yet.

    The live buffer object is returned, not a copy.
    """
    return logs
def on_flush(callback):
    """Register ``callback`` with each installed stream interceptor.

    The stdout interceptor is registered first, then the stderr one.
    Interceptors that have not been installed (still ``None``) are skipped.
    """
    for interceptor in (stdout_interceptor, stderr_interceptor):
        if interceptor is None:
            continue
        interceptor.on_flush(callback)
def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool = False):
    """Capture stdout/stderr into the shared log buffer and configure the root logger.

    Only the first call has any effect. Later calls return immediately, so the
    output streams are never wrapped twice and handlers are never duplicated.

    Args:
        log_level: Level passed to ``setLevel`` on the root logger.
        capacity: Maximum number of entries kept in the ``logs`` deque; the
            oldest entries are discarded first.
        use_stdout: If true, records below ERROR go to stdout and only
            ERROR/CRITICAL go to stderr. Otherwise all records go to stderr.
    """
    global logs, stdout_interceptor, stderr_interceptor
    # Compare against None rather than testing truthiness. A freshly created,
    # still-empty deque is falsy, so ``if logs:`` would let a second call
    # re-wrap the already-intercepted streams and add duplicate handlers.
    if logs is not None:
        return

    # Override output streams and log to buffer
    logs = deque(maxlen=capacity)
    stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
    stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)

    # Setup default global logger
    logger = logging.getLogger()
    logger.setLevel(log_level)

    # With no argument, StreamHandler writes to sys.stderr, which at this
    # point is the interceptor installed above.
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter("%(message)s"))

    if use_stdout:
        # Only errors and critical to stderr
        stream_handler.addFilter(lambda record: record.levelno >= logging.ERROR)

        # Lesser levels to stdout
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setFormatter(logging.Formatter("%(message)s"))
        stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
        logger.addHandler(stdout_handler)

    logger.addHandler(stream_handler)