|
import streamlit as st |
|
import joblib |
|
import numpy as np |
|
from predict import extract_features |
|
import os |
|
import tempfile |
|
from huggingface_hub import hf_hub_download, list_repo_files |
|
import logging |
|
import traceback |
|
import sklearn |
|
|
|
|
|
def _resolve_log_level(name):
    """Translate a level name such as "debug" or "WARNING" into its numeric value.

    Unknown names fall back to logging.INFO instead of making basicConfig
    raise at import time.
    """
    level = logging.getLevelName(str(name).strip().upper())
    return level if isinstance(level, int) else logging.INFO


# Root logging configuration. Every logger call in this file is INFO or higher,
# so a hard-coded DEBUG level added nothing of our own and only let third-party
# DEBUG records (e.g. HTTP client chatter during Hub downloads) through the
# root handler. Set LOG_LEVEL=DEBUG in the environment to get them back.
logging.basicConfig(
    level=_resolve_log_level(os.environ.get("LOG_LEVEL", "INFO")),
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)

logger = logging.getLogger(__name__)
|
|
|
|
|
# Browser-tab title/icon and page layout; kept ahead of every other st.* call.
_PAGE_CONFIG = {
    "page_title": "Healing Music Classifier",
    "page_icon": "🎵",
    "layout": "centered",
}
st.set_page_config(**_PAGE_CONFIG)
|
|
|
# Hugging Face Hub location of the trained artifacts.
HF_REPO_ID = "404Brain-Not-Found-yeah/healing-music-classifier"
MODEL_FILENAME = "models/model.joblib"
SCALER_FILENAME = "models/scaler.joblib"
# Download target directory, relative to the process working directory.
LOCAL_MODEL_DIR = "temp_models"


class _ModelLoadError(Exception):
    """A load failure that has already been logged and shown in the UI.

    Raised instead of returning ``(None, None)`` inside the cached loader:
    ``st.cache_resource`` does not cache exceptions, so the next rerun retries
    the load rather than reusing a cached failure for the life of the server.
    """


def _download_artifact(filename, label):
    """Download *filename* from HF_REPO_ID into LOCAL_MODEL_DIR and return its local path.

    *label* ("model" or "scaler") is only used in log and UI messages.
    Raises _ModelLoadError if the download fails.
    """
    try:
        path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=LOCAL_MODEL_DIR,
        )
    except Exception as e:
        logger.exception("Error downloading %s: %s", label, e)
        st.error(f"Error downloading {label}: {str(e)}")
        raise _ModelLoadError(f"download of {filename} failed") from e
    logger.info("%s downloaded to: %s", label.capitalize(), path)
    st.write(f"{label.capitalize()} downloaded to: {path}")
    return path


def _load_artifact(path):
    """Deserialize one artifact with joblib, falling back to plain pickle."""
    # SECURITY NOTE(review): joblib.load and pickle.load execute arbitrary code
    # embedded in the file. Only point HF_REPO_ID at a repository you trust.
    try:
        return joblib.load(path)
    except Exception as load_error:
        # A common cause is a scikit-learn version mismatch with the training env.
        logger.warning("Standard loading failed: %s", load_error)
        # encoding='latin1' lets pickle read Python 2-era pickles containing
        # NumPy arrays. If this also fails, the joblib error stays attached as
        # the exception context and appears in the logged traceback.
        import pickle
        with open(path, 'rb') as f:
            return pickle.load(f, encoding='latin1')


def _deserialize_artifacts(model_path, scaler_path):
    """Check both downloaded files exist, report their sizes, and load them.

    Raises _ModelLoadError (already reported) if a file is missing; any other
    exception from reading or unpickling propagates to the caller.
    """
    logger.info("Loading model and scaler...")
    for label, path in (("Model", model_path), ("Scaler", scaler_path)):
        if not os.path.exists(path):
            logger.error("%s file not found at: %s", label, path)
            st.error(f"{label} file not found at: {path}")
            raise _ModelLoadError(f"{label} file missing")

    model_size = os.path.getsize(model_path)
    scaler_size = os.path.getsize(scaler_path)
    logger.info("Model file size: %s bytes", model_size)
    logger.info("Scaler file size: %s bytes", scaler_size)
    st.write(f"Model file size: {model_size} bytes")
    st.write(f"Scaler file size: {scaler_size} bytes")

    return _load_artifact(model_path), _load_artifact(scaler_path)


@st.cache_resource
def _load_model_cached():
    """Download and deserialize ``(model, scaler)``; raise on any failure.

    Successful results are cached by Streamlit. Every failure raises
    _ModelLoadError after reporting it, so failures are never cached.
    """
    logger.info("Using scikit-learn version: %s", sklearn.__version__)
    st.write(f"Using scikit-learn version: {sklearn.__version__}")

    logger.info("Listing repository files...")
    try:
        files = list_repo_files(HF_REPO_ID)
    except Exception as e:
        logger.exception("Error listing repository files: %s", e)
        st.error(f"Error listing repository files: {str(e)}")
        raise _ModelLoadError("listing repository files failed") from e
    logger.info("Repository files: %s", files)
    st.write("Available files in repository:", files)

    os.makedirs(LOCAL_MODEL_DIR, exist_ok=True)
    logger.info("Created %s directory", LOCAL_MODEL_DIR)

    logger.info("Downloading model from Hugging Face Hub...")
    model_path = _download_artifact(MODEL_FILENAME, "model")
    scaler_path = _download_artifact(SCALER_FILENAME, "scaler")

    try:
        model, scaler = _deserialize_artifacts(model_path, scaler_path)
    except _ModelLoadError:
        raise
    except Exception as e:
        logger.exception("Error loading model/scaler files: %s", e)
        st.error(f"Error loading model/scaler files: {str(e)}")
        raise _ModelLoadError("deserialization failed") from e

    logger.info("Model and scaler loaded successfully")
    st.success("Model and scaler loaded successfully!")
    return model, scaler


def load_model():
    """Load the classifier and its feature scaler from the Hugging Face Hub.

    Returns:
        ``(model, scaler)`` on success, or ``(None, None)`` if anything failed.
        The failure has already been logged and shown in the Streamlit UI.
        Successful loads are cached; failed loads are retried on the next call.
    """
    try:
        return _load_model_cached()
    except _ModelLoadError:
        return None, None
    except Exception as e:
        logger.exception("Unexpected error in load_model: %s", e)
        st.error(f"Unexpected error in load_model: {str(e)}")
        return None, None


# Keep the cache-clearing API callers had when load_model itself was decorated.
load_model.clear = _load_model_cached.clear
|
|
|
def _healing_verdict(healing_percentage):
    """Classify a 0-100 healing percentage into a ``(status kind, message)`` pair.

    The kind is "success" (>= 75), "info" (>= 50) or "warning". It names the
    Streamlit call used to display the message.
    """
    if healing_percentage >= 75:
        return "success", "This music has strong healing properties! 🌟"
    if healing_percentage >= 50:
        return "info", "This music has moderate healing effects. ✨"
    return "warning", "This music has limited healing potential. 🎵"


def _predict_healing_probability(model, scaler, features):
    """Scale one feature vector and return the model's class-1 probability as a float."""
    # [features] wraps the single sample as a one-row batch for transform/predict_proba.
    scaled_features = scaler.transform([features])
    # NOTE(review): assumes predict_proba column 1 is the "healing" class
    # (i.e. model.classes_[1]) — confirm against the training labels.
    return float(model.predict_proba(scaled_features)[0][1])


def _render_result(healing_probability):
    """Show the probability as a bar, a percentage and a colour-coded verdict."""
    st.subheader("Analysis Results")
    healing_percentage = healing_probability * 100
    st.progress(healing_probability)
    st.write(f"Healing Index: {healing_percentage:.1f}%")
    kind, message = _healing_verdict(healing_percentage)
    show = {"success": st.success, "info": st.info, "warning": st.warning}[kind]
    show(message)


def _remove_temp_file(path):
    """Best-effort removal of a temp file; *path* may be None. Errors are logged, not raised."""
    if not path:
        return
    try:
        if os.path.exists(path):
            os.unlink(path)
    except Exception as e:
        logger.error("Failed to clean up temporary file: %s", e)


def main():
    """Render the upload page and, once a file is uploaded, score and display it."""
    st.title("🎵 Healing Music Classifier")
    st.write("""
    Upload your music file, and AI will analyze its healing potential!
    Supports mp3, wav formats.
    """)

    uploaded_file = st.file_uploader("Choose an audio file...", type=['mp3', 'wav'])
    if uploaded_file is None:
        return

    progress_bar = st.progress(0)
    status_text = st.empty()
    tmp_file_path = None
    completed = False

    try:
        # The feature extractor takes a path, so persist the upload, keeping its
        # extension. Record the path before writing so the finally-block can
        # delete the file even if the write itself fails.
        suffix = os.path.splitext(uploaded_file.name)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file_path = tmp_file.name
            tmp_file.write(uploaded_file.getvalue())

        status_text.text("Analyzing music...")
        progress_bar.progress(30)

        model, scaler = load_model()
        if model is None or scaler is None:
            st.error("Model loading failed. Please check the logs for details.")
            return
        progress_bar.progress(50)

        features = extract_features(tmp_file_path)
        if features is None:
            st.error("Failed to extract audio features. Please ensure the file is a valid audio file.")
            return
        progress_bar.progress(70)

        try:
            healing_probability = _predict_healing_probability(model, scaler, features)
        except Exception as e:
            logger.exception("Error during prediction: %s", e)
            st.error(f"Error during prediction: {str(e)}")
            return
        progress_bar.progress(90)

        _render_result(healing_probability)

        progress_bar.progress(100)
        status_text.text("Analysis complete!")
        completed = True

    except Exception as e:
        st.error(f"An unexpected error occurred: {str(e)}")
        logger.exception("Unexpected error")

    finally:
        _remove_temp_file(tmp_file_path)
        if not completed:
            # Don't leave a half-filled bar and "Analyzing music..." next to an error.
            progress_bar.empty()
            status_text.empty()


if __name__ == "__main__":
    main()
|
|