litagin's picture
Upload app.py
174f89c verified
raw
history blame
1.36 kB
import argparse
import json
import re
from pathlib import Path

import gradio as gr
import torch

from models import AudioClassifier
from utils import logger
def _checkpoint_sort_key(path: Path) -> list:
    """Return a natural-sort key for the file name of *path*.

    The name is split into alternating text / digit runs and digit runs are
    compared as integers, so ``model_10000.pth`` sorts after
    ``model_9000.pth`` (plain string sorting would put it first).
    """
    # With a capturing group, re.split yields text at even indices and digit
    # runs at odd indices, so every position holds a single, comparable type.
    parts = re.split(r"(\d+)", path.name)
    return [int(part) if index % 2 else part for index, part in enumerate(parts)]


def _latest_checkpoint(directory: Path) -> Path:
    """Return the checkpoint file to load from *directory*.

    ``model_final.pth`` wins if present. Otherwise the ``*.pth`` file whose
    name sorts last in natural order is used. That is presumably the most
    recent training step; TODO confirm the checkpoint naming scheme.

    Raises:
        FileNotFoundError: if *directory* contains no ``*.pth`` file.
    """
    final = directory / "model_final.pth"
    if final.exists():
        return final
    candidates = sorted(directory.glob("*.pth"), key=_checkpoint_sort_key)
    if not candidates:
        raise FileNotFoundError(f"No *.pth checkpoint found in {directory}")
    return candidates[-1]


# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Device: {device}")

# Directory holding the training config and the saved checkpoints.
ckpt_dir = Path("ckpt/")
config_path = ckpt_dir / "config.json"
# Raise explicitly: an `assert` would be stripped under `python -O`.
if not config_path.exists():
    raise FileNotFoundError(f"config.json not found in {ckpt_dir}")
# JSON is UTF-8 by spec; don't depend on the platform's default encoding.
config = json.loads(config_path.read_text(encoding="utf-8"))
# The "model" section of the config is forwarded as constructor kwargs.
model = AudioClassifier(device=device, **config["model"]).to(device)
# NOTE(review): confirm whether `infer_from_file` switches the model to eval
# mode itself; if not, a `model.eval()` call belongs here.

ckpt = _latest_checkpoint(ckpt_dir)
logger.info(f"Loading {ckpt}...")
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# For a plain state_dict, `weights_only=True` (torch >= 1.13) would be safer.
model.load_state_dict(torch.load(ckpt, map_location=device))
def classify_audio(audio_file: str):
    """Classify one audio clip with the loaded model.

    Args:
        audio_file: Path to the audio file, as delivered by the
            ``gr.Audio(type="filepath")`` input. Gradio passes ``None`` when
            the user submits without providing any audio.

    Returns:
        The value returned by ``model.infer_from_file``, which the
        interface renders in its text output.

    Raises:
        gr.Error: if no audio was provided, so the UI shows a readable
            message instead of an exception raised deep inside the model.
    """
    if not audio_file:
        raise gr.Error("Please upload or record an audio clip first.")
    logger.info(f"Classifying {audio_file}...")
    output = model.infer_from_file(audio_file)
    logger.success(f"Predicted: {output}")
    return output
# Markdown description shown in the Gradio UI (passed as `description=`).
# It is titled "NSFW audio classifier" and states that the output is the
# probability of each of three classes:
#   usual: ordinary speech
#   aegi:  moaning
#   chupa: kissing / slurping sounds
desc = (
    "\n"
    "# NSFW音声分類器\n"
    "出力は以下の3つのクラスの確率です。\n"
    "- usual: 通常の音声\n"
    "- aegi: 喘ぎ声\n"
    "- chupa: チュパ音(フェラやキス音声)\n"
)
# Web UI: a single audio input (handed to `classify_audio` as a file path),
# a single text output for the prediction, and flagging disabled.
iface = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(label="Input audio", type="filepath"),
    outputs=gr.Text(label="Classification"),
    description=desc,
    # NOTE(review): newer Gradio releases replace `allow_flagging` with
    # `flagging_mode`; confirm against the pinned Gradio version.
    allow_flagging="never",
)
# Start the server once the interface is fully built.
iface.launch()