litagin commited on
Commit
174f89c
1 Parent(s): d485dcb

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -7,21 +7,22 @@ import torch
7
  from models import AudioClassifier
8
  from utils import logger
9
 
 
 
10
 
11
  ckpt_dir = Path("ckpt/")
12
  config_path = ckpt_dir / "config.json"
13
  assert config_path.exists(), f"config.json not found in {ckpt_dir}"
14
  config = json.loads((ckpt_dir / "config.json").read_text())
15
- device = "cuda" if torch.cuda.is_available() else "cpu"
16
- model = AudioClassifier(device=device, **config["model"]).to(device)
17
 
 
18
  # Latest checkpoint
19
  if (ckpt_dir / "model_final.pth").exists():
20
  ckpt = ckpt_dir / "model_final.pth"
21
  else:
22
  ckpt = sorted(ckpt_dir.glob("*.pth"))[-1]
23
  logger.info(f"Loading {ckpt}...")
24
- model.load_state_dict(torch.load(ckpt))
25
 
26
 
27
  def classify_audio(audio_file: str):
 
7
  from models import AudioClassifier
8
  from utils import logger
9
 
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ logger.info(f"Device: {device}")
12
 
13
  ckpt_dir = Path("ckpt/")
14
  config_path = ckpt_dir / "config.json"
15
  assert config_path.exists(), f"config.json not found in {ckpt_dir}"
16
  config = json.loads((ckpt_dir / "config.json").read_text())
 
 
17
 
18
+ model = AudioClassifier(device=device, **config["model"]).to(device)
19
  # Latest checkpoint
20
  if (ckpt_dir / "model_final.pth").exists():
21
  ckpt = ckpt_dir / "model_final.pth"
22
  else:
23
  ckpt = sorted(ckpt_dir.glob("*.pth"))[-1]
24
  logger.info(f"Loading {ckpt}...")
25
+ model.load_state_dict(torch.load(ckpt, map_location=device))
26
 
27
 
28
  def classify_audio(audio_file: str):