Upload app.py
Browse files
app.py
CHANGED
@@ -7,21 +7,22 @@ import torch
|
|
7 |
from models import AudioClassifier
|
8 |
from utils import logger
|
9 |
|
|
|
|
|
10 |
|
11 |
ckpt_dir = Path("ckpt/")
|
12 |
config_path = ckpt_dir / "config.json"
|
13 |
assert config_path.exists(), f"config.json not found in {ckpt_dir}"
|
14 |
config = json.loads((ckpt_dir / "config.json").read_text())
|
15 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
16 |
-
model = AudioClassifier(device=device, **config["model"]).to(device)
|
17 |
|
|
|
18 |
# Latest checkpoint
|
19 |
if (ckpt_dir / "model_final.pth").exists():
|
20 |
ckpt = ckpt_dir / "model_final.pth"
|
21 |
else:
|
22 |
ckpt = sorted(ckpt_dir.glob("*.pth"))[-1]
|
23 |
logger.info(f"Loading {ckpt}...")
|
24 |
-
model.load_state_dict(torch.load(ckpt))
|
25 |
|
26 |
|
27 |
def classify_audio(audio_file: str):
|
|
|
7 |
from models import AudioClassifier
|
8 |
from utils import logger
|
9 |
|
10 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
11 |
+
logger.info(f"Device: {device}")
|
12 |
|
13 |
ckpt_dir = Path("ckpt/")
|
14 |
config_path = ckpt_dir / "config.json"
|
15 |
assert config_path.exists(), f"config.json not found in {ckpt_dir}"
|
16 |
config = json.loads((ckpt_dir / "config.json").read_text())
|
|
|
|
|
17 |
|
18 |
+
model = AudioClassifier(device=device, **config["model"]).to(device)
|
19 |
# Latest checkpoint
|
20 |
if (ckpt_dir / "model_final.pth").exists():
|
21 |
ckpt = ckpt_dir / "model_final.pth"
|
22 |
else:
|
23 |
ckpt = sorted(ckpt_dir.glob("*.pth"))[-1]
|
24 |
logger.info(f"Loading {ckpt}...")
|
25 |
+
model.load_state_dict(torch.load(ckpt, map_location=device))
|
26 |
|
27 |
|
28 |
def classify_audio(audio_file: str):
|