import os
import numpy as np
import librosa
import pandas as pd
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
import joblib
import warnings
import soundfile as sf
import logging
import traceback
import sys

# Set up detailed logging to both stdout and a log file
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('training.log')
    ]
)
logger = logging.getLogger(__name__)

# Suppress noisy library warnings (e.g. librosa deprecation notices)
warnings.filterwarnings('ignore')

def extract_features(file_path):
    """Extract audio features from a file."""
    try:
        logger.info(f"Starting feature extraction for: {file_path}")
        
        # Verify file exists
        if not os.path.exists(file_path):
            logger.error(f"File does not exist: {file_path}")
            return None
            
        # Verify file format
        try:
            with sf.SoundFile(file_path) as sf_file:
                logger.info(f"Audio file info: {sf_file.samplerate}Hz, {sf_file.channels} channels")
        except Exception as e:
            logger.error(f"Error reading audio file with soundfile: {str(e)}\n{traceback.format_exc()}")
            return None

        # Load audio file with error handling
        try:
            logger.info("Loading audio file...")
            y, sr = librosa.load(file_path, duration=30, sr=None)
            if len(y) == 0:
                logger.error("Audio file is empty")
                return None
            logger.info(f"Successfully loaded audio: {len(y)} samples, {sr}Hz sample rate")
        except Exception as e:
            logger.error(f"Error loading audio: {str(e)}\n{traceback.format_exc()}")
            return None

        # Ensure minimum duration
        duration = len(y) / sr
        logger.info(f"Audio duration: {duration:.2f} seconds")
        if duration < 1.0:
            logger.error("Audio file is too short (less than 1 second)")
            return None

        features_dict = {}
        
        try:
            # 1. MFCC (13 features x 2 = 26)
            logger.info("Extracting MFCC features...")
            mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
            features_dict['mfccs_mean'] = np.mean(mfccs, axis=1)
            features_dict['mfccs_var'] = np.var(mfccs, axis=1)
            logger.info(f"MFCC features shape: {mfccs.shape}")
        except Exception as e:
            logger.error(f"Error extracting MFCC: {str(e)}\n{traceback.format_exc()}")
            return None

        try:
            # 2. Chroma Features
            logger.info("Extracting chroma features...")
            chroma = librosa.feature.chroma_stft(y=y, sr=sr)
            features_dict['chroma'] = np.mean(chroma, axis=1)
            logger.info(f"Chroma features shape: {chroma.shape}")
        except Exception as e:
            logger.error(f"Error extracting chroma features: {str(e)}\n{traceback.format_exc()}")
            return None

        # Combine all features
        try:
            logger.info("Combining features...")
            features = np.concatenate([
                features_dict['mfccs_mean'],
                features_dict['mfccs_var'],
                features_dict['chroma']
            ])
            logger.info(f"Final feature vector shape: {features.shape}")
            return features
        except Exception as e:
            logger.error(f"Error combining features: {str(e)}\n{traceback.format_exc()}")
            return None

    except Exception as e:
        logger.error(f"Unexpected error in feature extraction: {str(e)}\n{traceback.format_exc()}")
        return None

def prepare_dataset():
    """Prepare a synthetic dataset as a stand-in for healing/non-healing music features."""
    # Use a synthetic dataset for the initial deployment
    print("Using synthetic dataset for initial deployment...")
    np.random.seed(42)
    n_samples = 100   # number of synthetic samples
    n_features = 38   # 26 MFCC features (13 means + 13 variances) + 12 chroma features

    # Draw synthetic feature vectors from a standard normal distribution
    synthetic_features = np.random.normal(0, 1, (n_samples, n_features))
    # Create balanced labels (half healing = 1, half non-healing = 0)
    synthetic_labels = np.concatenate([np.ones(n_samples//2), np.zeros(n_samples//2)])

    return synthetic_features, synthetic_labels
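
# A minimal sketch of how prepare_dataset() could be replaced once real audio is
# available, reusing extract_features() above. The folder names "healing" and
# "non_healing" and the helper itself are assumptions, not part of the current
# project layout.
def prepare_dataset_from_folders(healing_dir="healing", non_healing_dir="non_healing"):
    """Build (features, labels) arrays from two folders of audio files."""
    features, labels = [], []
    for folder, label in [(healing_dir, 1), (non_healing_dir, 0)]:
        for name in os.listdir(folder):
            if not name.lower().endswith(('.wav', '.mp3', '.flac')):
                continue  # skip non-audio files
            vec = extract_features(os.path.join(folder, name))
            if vec is not None:  # skip files that failed extraction
                features.append(vec)
                labels.append(label)
    return np.array(features), np.array(labels)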

def train_and_evaluate_model():
    """Train and evaluate the model."""
    # Prepare dataset
    print("Preparing dataset...")
    X, y = prepare_dataset()
    
    # Scale features
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)
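    # Note: fitting the scaler on the full dataset before splitting leaks test
    # statistics into training. That is harmless for this synthetic data, but
    # with real audio features the scaler should be fit on the training split only.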
    
    # Split dataset
    X_train, X_test, y_train, y_test = train_test_split(
        X_scaled, y, test_size=0.2, random_state=42
    )
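    # Note: X_test/y_test are held out but not used below; evaluation currently
    # relies on 5-fold cross-validation over the full (scaled) dataset instead.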
    
    # Train model
    print("Training model...")
    model = RandomForestClassifier(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)
    
    # Evaluate model
    print("Evaluating model...")
    cv_scores = cross_val_score(model, X_scaled, y, cv=5)
    print(f"Cross-validation scores: {cv_scores}")
    print(f"Average CV score: {cv_scores.mean():.3f} (+/- {cv_scores.std() * 2:.3f})")
    
    # Save model and scaler
    print("Saving model and scaler...")
    model_dir = os.path.join(os.path.dirname(__file__), "models")
    os.makedirs(model_dir, exist_ok=True)
    
    model_path = os.path.join(model_dir, "model.joblib")
    scaler_path = os.path.join(model_dir, "scaler.joblib")
    
    joblib.dump(model, model_path)
    joblib.dump(scaler, scaler_path)
    
    return model, scaler
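
# A minimal inference sketch. predict_file is an illustrative helper (an
# assumption, not part of the original pipeline) that relies on the artifacts
# written by train_and_evaluate_model() above and on extract_features()
# producing the same 38-feature layout used in training.
def predict_file(file_path):
    """Classify one audio file with the saved model; returns None on failure."""
    model_dir = os.path.join(os.path.dirname(__file__), "models")
    model = joblib.load(os.path.join(model_dir, "model.joblib"))
    scaler = joblib.load(os.path.join(model_dir, "scaler.joblib"))
    features = extract_features(file_path)
    if features is None:
        return None  # extraction failed; see training.log for details
    prediction = model.predict(scaler.transform(features.reshape(1, -1)))
    return "healing" if prediction[0] == 1 else "non-healing"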

if __name__ == "__main__":
    train_and_evaluate_model()