404Brain-Not-Found-yeah's picture
Update app.py
81ce00b verified
raw
history blame
8.11 kB
import streamlit as st
import joblib
import numpy as np
from predict import extract_features
import os
import tempfile
from huggingface_hub import hf_hub_download, list_repo_files
import logging
import traceback
import sklearn
# Module-wide logging: timestamped records at DEBUG level and above, configured
# once on the root logger when the script is executed.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)

# Streamlit page settings (tab title, favicon, centered layout). Kept ahead of
# every other st.* call in this script.
st.set_page_config(
    layout="centered",
    page_icon="🎵",
    page_title="Healing Music Classifier",
)
# Hugging Face Hub repository that hosts the trained model and scaler.
MODEL_REPO_ID = "404Brain-Not-Found-yeah/healing-music-classifier"
# Directory (relative to the working directory) the artifacts are downloaded into.
MODEL_DIR = "temp_models"


class ModelLoadError(RuntimeError):
    """Raised when a model artifact cannot be downloaded or deserialized."""


def _download_artifact(filename, label):
    """Download ``filename`` from ``MODEL_REPO_ID`` and return its local path.

    Raises:
        ModelLoadError: if the download fails or the file is missing afterwards.
    """
    try:
        path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=filename,
            local_dir=MODEL_DIR,
        )
    except Exception as e:
        raise ModelLoadError(f"Error downloading {label}: {str(e)}") from e
    if not os.path.exists(path):
        raise ModelLoadError(f"{label.capitalize()} file not found at: {path}")
    logger.info(f"{label.capitalize()} downloaded to: {path} ({os.path.getsize(path)} bytes)")
    return path


def _load_artifact(path):
    """Deserialize one artifact with joblib, falling back to a latin1 pickle load.

    The ``pickle.load(..., encoding='latin1')`` fallback is kept from the
    previous implementation. If the fallback also fails, the original joblib
    error is re-raised (the fallback error is attached as its context),
    because it is normally the informative one.

    Security note: joblib and pickle both execute code embedded in the file.
    That is acceptable only because the files come from the fixed
    ``MODEL_REPO_ID`` repository. Never call this on user-supplied files.
    """
    try:
        return joblib.load(path)
    except Exception as load_error:
        logger.warning(f"Standard loading of {path} failed: {str(load_error)}")
        import pickle
        try:
            with open(path, 'rb') as f:
                return pickle.load(f, encoding='latin1')
        except Exception:
            raise load_error


@st.cache_resource
def _load_model_cached():
    """Download and deserialize the model and scaler (cached per process).

    Failures are raised, not returned. ``st.cache_resource`` stores return
    values, so returning ``(None, None)`` here would pin a transient failure
    (e.g. a network error) for the rest of the process.

    Diagnostics go to the log only, so end users do not see internal paths
    or repository listings.

    Returns:
        ``(model, scaler)`` loaded from ``models/model.joblib`` and
        ``models/scaler.joblib``.

    Raises:
        ModelLoadError: if either artifact cannot be downloaded or loaded.
    """
    logger.info(f"Using scikit-learn version: {sklearn.__version__}")

    # Diagnostic only: failing to list the repository must not block loading.
    logger.info("Listing repository files...")
    try:
        files = list_repo_files(MODEL_REPO_ID)
        logger.info(f"Repository files: {files}")
    except Exception as e:
        logger.warning(f"Error listing repository files: {str(e)}\n{traceback.format_exc()}")

    os.makedirs(MODEL_DIR, exist_ok=True)
    logger.info(f"Created {MODEL_DIR} directory")

    logger.info("Downloading model from Hugging Face Hub...")
    model_path = _download_artifact("models/model.joblib", "model")
    scaler_path = _download_artifact("models/scaler.joblib", "scaler")

    logger.info("Loading model and scaler...")
    try:
        model = _load_artifact(model_path)
        scaler = _load_artifact(scaler_path)
    except Exception as e:
        raise ModelLoadError(f"Error loading model/scaler files: {str(e)}") from e

    logger.info("Model and scaler loaded successfully")
    return model, scaler


def load_model():
    """Return the ``(model, scaler)`` pair, or ``(None, None)`` on failure.

    The download and deserialization are cached by ``_load_model_cached``.
    This wrapper turns any failure into an ``st.error`` message plus a logged
    traceback. Because the exception escapes the cached function, the next
    call retries instead of reusing a cached failure.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        logger.error(f"Error in load_model: {str(e)}\n{traceback.format_exc()}")
        if isinstance(e, ModelLoadError):
            st.error(str(e))
        else:
            st.error(f"Unexpected error in load_model: {str(e)}")
        return None, None


# Keep the ``load_model.clear()`` API that the cache decorator used to provide.
load_model.clear = _load_model_cached.clear
def main():
    """Render the classifier page and analyse an uploaded audio file.

    Flow:
    1. Save the upload to a temporary file.
    2. Load the model and scaler via ``load_model``.
    3. Extract features with ``extract_features``.
    4. Show the predicted probability as a "Healing Index".

    The temporary file is always removed. The progress bar reaches 100% and
    reports "Analysis complete!" only when every step succeeded.
    """
    st.title("🎵 Healing Music Classifier")
    st.write("""
    Upload your music file, and AI will analyze its healing potential!
    Supports mp3, wav formats.
    """)

    # Add file upload component
    uploaded_file = st.file_uploader("Choose an audio file...", type=['mp3', 'wav'])
    if uploaded_file is None:
        return

    progress_bar = st.progress(0)
    status_text = st.empty()
    tmp_file_path = None
    succeeded = False
    try:
        # delete=False so the path can be reopened by extract_features after
        # the handle is closed; cleanup happens in the finally block below.
        suffix = os.path.splitext(uploaded_file.name)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            # Record the path before writing so a failed write is still cleaned up.
            tmp_file_path = tmp_file.name
            tmp_file.write(uploaded_file.getvalue())

        status_text.text("Analyzing music...")
        progress_bar.progress(30)

        model, scaler = load_model()
        if model is None or scaler is None:
            st.error("Model loading failed. Please check the logs for details.")
            return
        progress_bar.progress(50)

        features = extract_features(tmp_file_path)
        if features is None:
            st.error("Failed to extract audio features. Please ensure the file is a valid audio file.")
            return
        progress_bar.progress(70)

        try:
            scaled_features = scaler.transform([features])
            # NOTE(review): assumes predict_proba column 1 is the "healing"
            # class (i.e. model.classes_ == [0, 1]) - confirm against training.
            healing_probability = float(model.predict_proba(scaled_features)[0][1])
        except Exception as e:
            logger.error(f"Error during prediction: {str(e)}\n{traceback.format_exc()}")
            st.error(f"Error during prediction: {str(e)}")
            return
        progress_bar.progress(90)

        # Display results
        st.subheader("Analysis Results")
        healing_percentage = healing_probability * 100
        st.progress(healing_probability)
        st.write(f"Healing Index: {healing_percentage:.1f}%")
        if healing_percentage >= 75:
            st.success("This music has strong healing properties! 🌟")
        elif healing_percentage >= 50:
            st.info("This music has moderate healing effects. ✨")
        else:
            st.warning("This music has limited healing potential. 🎵")
        succeeded = True
    except Exception as e:
        st.error(f"An unexpected error occurred: {str(e)}")
        logger.exception("Unexpected error")
    finally:
        # Clean up temporary file
        if tmp_file_path is not None:
            try:
                if os.path.exists(tmp_file_path):
                    os.unlink(tmp_file_path)
            except Exception as e:
                logger.error(f"Failed to clean up temporary file: {str(e)}")
        # Only claim completion when the analysis actually succeeded.
        if succeeded:
            progress_bar.progress(100)
            status_text.text("Analysis complete!")
        else:
            status_text.text("Analysis failed.")


if __name__ == "__main__":
    main()