404Brain-Not-Found-yeah
commited on
Commit
•
f6dae0a
1
Parent(s):
1c5148f
Update app.py
Browse files
app.py
CHANGED
@@ -9,9 +9,9 @@ from huggingface_hub import hf_hub_download, list_repo_files
|
|
9 |
import logging
|
10 |
import traceback
|
11 |
import sklearn
|
|
|
12 |
|
13 |
# 版本检查
|
14 |
-
import pkg_resources
|
15 |
required_versions = {
|
16 |
'numpy': '1.23.5',
|
17 |
'scipy': '1.10.1',
|
@@ -23,17 +23,17 @@ def check_versions():
|
|
23 |
for package, required_version in required_versions.items():
|
24 |
try:
|
25 |
installed_version = pkg_resources.get_distribution(package).version
|
26 |
-
st.write(f"{package} version: {installed_version} (required: {required_version})")
|
27 |
if installed_version != required_version:
|
28 |
-
|
|
|
29 |
except pkg_resources.DistributionNotFound:
|
30 |
-
|
31 |
return False
|
32 |
return True
|
33 |
|
34 |
# Set up logging
|
35 |
logging.basicConfig(
|
36 |
-
level=logging.
|
37 |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
38 |
)
|
39 |
logger = logging.getLogger(__name__)
|
@@ -45,35 +45,18 @@ st.set_page_config(
|
|
45 |
layout="centered"
|
46 |
)
|
47 |
|
48 |
-
# 在加载模型之前检查版本
|
49 |
-
if not check_versions():
|
50 |
-
st.error("Package version requirements not met. Please check the logs.")
|
51 |
-
st.stop()
|
52 |
-
|
53 |
@st.cache_resource
|
54 |
def load_model():
|
55 |
"""Load model from Hugging Face Hub"""
|
56 |
try:
|
57 |
-
#
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
# 首先列出仓库中的所有文件
|
62 |
-
logger.info("Listing repository files...")
|
63 |
-
try:
|
64 |
-
files = list_repo_files("404Brain-Not-Found-yeah/healing-music-classifier")
|
65 |
-
logger.info(f"Repository files: {files}")
|
66 |
-
st.write("Available files in repository:", files)
|
67 |
-
except Exception as e:
|
68 |
-
logger.error(f"Error listing repository files: {str(e)}\n{traceback.format_exc()}")
|
69 |
-
st.error(f"Error listing repository files: {str(e)}")
|
70 |
return None, None
|
71 |
|
72 |
# 创建临时目录
|
73 |
os.makedirs("temp_models", exist_ok=True)
|
74 |
-
logger.info("Created temp_models directory")
|
75 |
|
76 |
-
logger.info("Downloading model from Hugging Face Hub...")
|
77 |
# 下载模型文件
|
78 |
try:
|
79 |
model_path = hf_hub_download(
|
@@ -81,48 +64,21 @@ def load_model():
|
|
81 |
filename="models/model.joblib",
|
82 |
local_dir="temp_models"
|
83 |
)
|
84 |
-
logger.info(f"Model downloaded to: {model_path}")
|
85 |
-
st.write(f"Model downloaded to: {model_path}")
|
86 |
-
except Exception as e:
|
87 |
-
logger.error(f"Error downloading model: {str(e)}\n{traceback.format_exc()}")
|
88 |
-
st.error(f"Error downloading model: {str(e)}")
|
89 |
-
return None, None
|
90 |
-
|
91 |
-
# 下载scaler文件
|
92 |
-
try:
|
93 |
scaler_path = hf_hub_download(
|
94 |
repo_id="404Brain-Not-Found-yeah/healing-music-classifier",
|
95 |
filename="models/scaler.joblib",
|
96 |
local_dir="temp_models"
|
97 |
)
|
98 |
-
logger.info(f"Scaler downloaded to: {scaler_path}")
|
99 |
-
st.write(f"Scaler downloaded to: {scaler_path}")
|
100 |
except Exception as e:
|
101 |
-
logger.error(f"Error downloading
|
102 |
-
st.error(f"Error downloading scaler: {str(e)}")
|
103 |
return None, None
|
104 |
|
105 |
# 加载模型文件
|
106 |
try:
|
107 |
-
|
108 |
-
|
109 |
-
if not os.path.exists(model_path):
|
110 |
-
logger.error(f"Model file not found at: {model_path}")
|
111 |
-
st.error(f"Model file not found at: {model_path}")
|
112 |
-
return None, None
|
113 |
-
if not os.path.exists(scaler_path):
|
114 |
-
logger.error(f"Scaler file not found at: {scaler_path}")
|
115 |
-
st.error(f"Scaler file not found at: {scaler_path}")
|
116 |
return None, None
|
117 |
|
118 |
-
# 检查文件大小
|
119 |
-
model_size = os.path.getsize(model_path)
|
120 |
-
scaler_size = os.path.getsize(scaler_path)
|
121 |
-
logger.info(f"Model file size: {model_size} bytes")
|
122 |
-
logger.info(f"Scaler file size: {scaler_size} bytes")
|
123 |
-
st.write(f"Model file size: {model_size} bytes")
|
124 |
-
st.write(f"Scaler file size: {scaler_size} bytes")
|
125 |
-
|
126 |
# 尝���使用不同的pickle协议加载
|
127 |
try:
|
128 |
model = joblib.load(model_path)
|
@@ -136,17 +92,13 @@ def load_model():
|
|
136 |
with open(scaler_path, 'rb') as f:
|
137 |
scaler = pickle.load(f, encoding='latin1')
|
138 |
|
139 |
-
logger.info("Model and scaler loaded successfully")
|
140 |
-
st.success("Model and scaler loaded successfully!")
|
141 |
return model, scaler
|
142 |
except Exception as e:
|
143 |
-
logger.error(f"Error loading model/scaler files: {str(e)}
|
144 |
-
st.error(f"Error loading model/scaler files: {str(e)}")
|
145 |
return None, None
|
146 |
|
147 |
except Exception as e:
|
148 |
-
logger.error(f"Unexpected error in load_model: {str(e)}
|
149 |
-
st.error(f"Unexpected error in load_model: {str(e)}")
|
150 |
return None, None
|
151 |
|
152 |
def main():
|
@@ -167,7 +119,6 @@ def main():
|
|
167 |
try:
|
168 |
# Create temporary file
|
169 |
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
|
170 |
-
# Write uploaded file content
|
171 |
tmp_file.write(uploaded_file.getvalue())
|
172 |
tmp_file_path = tmp_file.name
|
173 |
|
@@ -178,7 +129,7 @@ def main():
|
|
178 |
# Load model
|
179 |
model, scaler = load_model()
|
180 |
if model is None or scaler is None:
|
181 |
-
st.error("Model loading failed. Please
|
182 |
return
|
183 |
|
184 |
progress_bar.progress(50)
|
@@ -197,8 +148,8 @@ def main():
|
|
197 |
healing_probability = model.predict_proba(scaled_features)[0][1]
|
198 |
progress_bar.progress(90)
|
199 |
except Exception as e:
|
200 |
-
logger.error(f"Error during prediction: {str(e)}
|
201 |
-
st.error(
|
202 |
return
|
203 |
|
204 |
# Display results
|
@@ -220,7 +171,7 @@ def main():
|
|
220 |
st.warning("This music has limited healing potential. 🎵")
|
221 |
|
222 |
except Exception as e:
|
223 |
-
st.error(
|
224 |
logger.exception("Unexpected error")
|
225 |
|
226 |
finally:
|
|
|
9 |
import logging
|
10 |
import traceback
|
11 |
import sklearn
|
12 |
+
import pkg_resources
|
13 |
|
14 |
# 版本检查
|
|
|
15 |
required_versions = {
|
16 |
'numpy': '1.23.5',
|
17 |
'scipy': '1.10.1',
|
|
|
23 |
for package, required_version in required_versions.items():
|
24 |
try:
|
25 |
installed_version = pkg_resources.get_distribution(package).version
|
|
|
26 |
if installed_version != required_version:
|
27 |
+
logger.warning(f"Warning: {package} version mismatch. Required: {required_version}, Installed: {installed_version}")
|
28 |
+
return False
|
29 |
except pkg_resources.DistributionNotFound:
|
30 |
+
logger.error(f"Error: {package} not found!")
|
31 |
return False
|
32 |
return True
|
33 |
|
34 |
# Set up logging
|
35 |
logging.basicConfig(
|
36 |
+
level=logging.INFO,
|
37 |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
38 |
)
|
39 |
logger = logging.getLogger(__name__)
|
|
|
45 |
layout="centered"
|
46 |
)
|
47 |
|
|
|
|
|
|
|
|
|
|
|
48 |
@st.cache_resource
|
49 |
def load_model():
|
50 |
"""Load model from Hugging Face Hub"""
|
51 |
try:
|
52 |
+
# 检查版本
|
53 |
+
if not check_versions():
|
54 |
+
logger.error("Package version requirements not met")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
return None, None
|
56 |
|
57 |
# 创建临时目录
|
58 |
os.makedirs("temp_models", exist_ok=True)
|
|
|
59 |
|
|
|
60 |
# 下载模型文件
|
61 |
try:
|
62 |
model_path = hf_hub_download(
|
|
|
64 |
filename="models/model.joblib",
|
65 |
local_dir="temp_models"
|
66 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
scaler_path = hf_hub_download(
|
68 |
repo_id="404Brain-Not-Found-yeah/healing-music-classifier",
|
69 |
filename="models/scaler.joblib",
|
70 |
local_dir="temp_models"
|
71 |
)
|
|
|
|
|
72 |
except Exception as e:
|
73 |
+
logger.error(f"Error downloading model files: {str(e)}")
|
|
|
74 |
return None, None
|
75 |
|
76 |
# 加载模型文件
|
77 |
try:
|
78 |
+
if not os.path.exists(model_path) or not os.path.exists(scaler_path):
|
79 |
+
logger.error("Model files not found")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
return None, None
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
# 尝���使用不同的pickle协议加载
|
83 |
try:
|
84 |
model = joblib.load(model_path)
|
|
|
92 |
with open(scaler_path, 'rb') as f:
|
93 |
scaler = pickle.load(f, encoding='latin1')
|
94 |
|
|
|
|
|
95 |
return model, scaler
|
96 |
except Exception as e:
|
97 |
+
logger.error(f"Error loading model/scaler files: {str(e)}")
|
|
|
98 |
return None, None
|
99 |
|
100 |
except Exception as e:
|
101 |
+
logger.error(f"Unexpected error in load_model: {str(e)}")
|
|
|
102 |
return None, None
|
103 |
|
104 |
def main():
|
|
|
119 |
try:
|
120 |
# Create temporary file
|
121 |
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
|
|
|
122 |
tmp_file.write(uploaded_file.getvalue())
|
123 |
tmp_file_path = tmp_file.name
|
124 |
|
|
|
129 |
# Load model
|
130 |
model, scaler = load_model()
|
131 |
if model is None or scaler is None:
|
132 |
+
st.error("Model loading failed. Please try again later.")
|
133 |
return
|
134 |
|
135 |
progress_bar.progress(50)
|
|
|
148 |
healing_probability = model.predict_proba(scaled_features)[0][1]
|
149 |
progress_bar.progress(90)
|
150 |
except Exception as e:
|
151 |
+
logger.error(f"Error during prediction: {str(e)}")
|
152 |
+
st.error("Error during prediction. Please try again.")
|
153 |
return
|
154 |
|
155 |
# Display results
|
|
|
171 |
st.warning("This music has limited healing potential. 🎵")
|
172 |
|
173 |
except Exception as e:
|
174 |
+
st.error("An unexpected error occurred. Please try again.")
|
175 |
logger.exception("Unexpected error")
|
176 |
|
177 |
finally:
|