# MERaLiON-AudioLLM: src/logger.py
# Last commit 8573823 by YingxuHe: "update demo"
import functools
import io
import json
import logging
import os
import time
from threading import Thread, Lock

import streamlit as st
from huggingface_hub import HfApi

from src.utils import get_current_strftime
# Module-wide mutex protecting the Logger's in-memory buffers and counters.
# The register_* methods take it when appending, and the background sync
# loop takes it when swapping a buffer out for upload.
logger_lock: Lock = Lock()
def threaded(fn):
def wrapper(*args, **kwargs):
thread = Thread(target=fn, args=args, kwargs=kwargs)
thread.start()
return thread
return wrapper
class Logger:
    """Buffer session/query/audio records in memory and periodically upload them.

    ``register_session`` and ``register_query`` append rows to three in-memory
    buffers (``session_data``, ``query_data``, ``audio_data``) while holding
    the module-level ``logger_lock``. ``__init__`` starts a background loop
    (``sync_data``). Every ``sync_interval`` seconds that loop drains each
    non-empty buffer. It uploads the buffer as a UTF-8 JSON Lines file to
    ``<buffer name>/<get_current_strftime()>.json`` in the dataset repo named
    by the ``LOGGING_REPO_NAME`` environment variable, authenticating with
    ``HF_TOKEN``.
    """

    def __init__(self):
        # Creation timestamp; used as the prefix of every session id.
        self.app_id = get_current_strftime()
        self.session_increment = 0
        self.query_increment = 0
        self.sync_interval = 180  # seconds to sleep between upload cycles
        self.session_data = []
        self.query_data = []
        self.audio_data = []
        # sync_data is wrapped by @threaded: this call starts the background
        # upload loop and returns immediately.
        self.sync_data()

    def register_session(self) -> str:
        """Allocate a new session id and record its creation time.

        Returns:
            An id of the form ``"<app_id>+<n>"``, where ``n`` is a counter
            that increases by one per registered session.
        """
        with logger_lock:
            # Read and bump the counter under the lock so concurrent callers
            # can never be handed the same id.
            new_session_id = f"{self.app_id}+{self.session_increment}"
            self.session_increment += 1
            self.session_data.append({
                "session_id": new_session_id,
                "creation_time": get_current_strftime(),
            })
        return new_session_id

    def register_query(self,
                       session_id,
                       base64_audio,
                       text_input,
                       response,
                       **kwargs
                       ):
        """Record one query as a metadata row plus a separate audio row.

        Both rows carry the same ``session_id``, ``query_id`` and
        ``creation_time``, so they can be joined after upload.

        Args:
            session_id: presumably an id from ``register_session`` (not
                validated here).
            base64_audio: audio payload stored verbatim under ``"audio"`` in
                ``audio_data``. It must be JSON-serialisable.
            text_input: stored under ``"text"``.
            response: stored under ``"response"``.
            **kwargs: extra JSON-serialisable fields merged into the query
                row. They override same-named fields above.
        """
        current_time = get_current_strftime()
        with logger_lock:
            # Allocate the id under the lock so concurrent queries never share it.
            new_query_id = self.query_increment
            self.query_increment += 1

            query_row = {
                "session_id": session_id,
                "query_id": new_query_id,
                "creation_time": current_time,
                "text": text_input,
                "response": response,
            }
            query_row.update(kwargs)
            self.query_data.append(query_row)

            self.audio_data.append({
                "session_id": session_id,
                "query_id": new_query_id,
                "creation_time": current_time,
                "audio": base64_audio,
            })

    @staticmethod
    def _encode_rows(data_name, rows) -> bytes:
        """Serialise ``rows`` as UTF-8 JSON Lines.

        Rows that cannot be encoded are logged and dropped. Without this, one
        bad record (e.g. a non-serialisable kwarg) would block its buffer.
        """
        encoded = []
        for row in rows:
            try:
                encoded.append(
                    (json.dumps(row, ensure_ascii=False) + "\n").encode("utf-8")
                )
            except (TypeError, ValueError):  # ValueError covers UnicodeEncodeError
                logging.getLogger(__name__).exception(
                    "Dropping a %s row that cannot be JSON/UTF-8 encoded", data_name
                )
        return b"".join(encoded)

    @threaded
    def sync_data(self):
        """Upload buffered rows every ``sync_interval`` seconds, forever.

        This runs in its own thread (see ``@threaded``). Each cycle swaps every
        buffer for an empty list under the lock, then encodes and uploads the
        drained rows outside the lock. A failed upload is logged, and its rows
        are put back in front of any rows that arrived meanwhile, so they are
        retried next cycle. Without this, one network error would end the loop.
        """
        api = HfApi()
        while True:
            time.sleep(self.sync_interval)
            for data_name in ["session_data", "query_data", "audio_data"]:
                with logger_lock:
                    last_data = getattr(self, data_name, [])
                    setattr(self, data_name, [])
                if not last_data:
                    continue

                payload = self._encode_rows(data_name, last_data)
                if not payload:  # every row was unencodable
                    continue

                try:
                    # Raw bytes avoid any dependence on a stream's position
                    # and need no explicit close.
                    api.upload_file(
                        path_or_fileobj=payload,
                        path_in_repo=f"{data_name}/{get_current_strftime()}.json",
                        repo_id=os.getenv("LOGGING_REPO_NAME"),
                        repo_type="dataset",
                        token=os.getenv('HF_TOKEN')
                    )
                except Exception:
                    # Thread boundary: log and keep the loop alive rather than
                    # letting the exception kill the sync thread.
                    logging.getLogger(__name__).exception(
                        "Uploading %d %s rows failed; re-queued for next sync",
                        len(last_data), data_name,
                    )
                    with logger_lock:
                        setattr(
                            self,
                            data_name,
                            last_data + getattr(self, data_name, []),
                        )
@st.cache_resource()
def load_logger():
    """Return the shared ``Logger`` instance, building it on first use.

    ``st.cache_resource`` memoises the return value across reruns and
    sessions. As a result, only one Logger and one background sync thread
    exist per server process.
    """
    shared_logger = Logger()
    return shared_logger