gaepago_model / .ipynb_checkpoints /app-checkpoint.py
yumyeom's picture
app.py & etc 수정
5ebbb35
raw
history blame
1.77 kB
# Gaepago model V1 (CPU Test)
# Import packages (all third-party)
import gradio as gr
import torch
from datasets import Audio
from datasets import Dataset
from transformers import AutoFeatureExtractor
from transformers import AutoModelForAudioClassification
from transformers import pipeline
# Hugging Face Hub identifiers for the fine-tuned classifier and its dataset.
MODEL_NAME = "Gae8J/gaepago-20"
DATASET_NAME = "Gae8J/modeling_v1"  # currently unused in this script
# Load the matching feature extractor and audio-classification model, then
# pin the model to the CPU, since this demo is meant to run without a GPU.
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME).to("cpu")
# Gaepago Inference Model function
def gaepago_fn(tmp_audio_dir):
print(tmp_audio_dir)
audio_dataset = Dataset.from_dict({"audio": [tmp_audio_dir]}).cast_column("audio", Audio(sampling_rate=16000))
inputs = feature_extractor(audio_dataset[0]["audio"]["array"]
,sampling_rate=audio_dataset[0]["audio"]["sampling_rate"]
,return_tensors="pt")
with torch.no_grad():
logits = model(**inputs).logits
predicted_class_ids = torch.argmax(logits).item()
predicted_label = model.config.id2label[predicted_class_ids]
return predicted_label
# Main: build the Gradio Blocks UI, wire the button to gaepago_fn, and launch.
main_api = gr.Blocks()
with main_api:
    gr.Markdown("## 8J Gaepago Demo(with CPU)")
    with gr.Row():
        # NOTE(review): `source=` is the Gradio 3.x keyword; Gradio 4+ renamed
        # it to `sources=["microphone"]`. Confirm against the pinned version.
        # Label: "Press the record button and let us hear what Choco says".
        audio = gr.Audio(
            source="microphone",
            type="filepath",  # gaepago_fn expects a file path
            label="녹음버튼을 눌러 초코가 하는 말을 들려주세요",
        )
        # Label: "What Choco is saying right now is...".
        transcription = gr.Textbox(label="지금 초코가 하는 말은...")
    # Button text: "Translate dog language!".
    b1 = gr.Button("강아지 언어 번역!")
    b1.click(gaepago_fn, inputs=audio, outputs=transcription)
    # TODO: add gr.Examples(examples=example_list, inputs=[audio]) once an
    # example_list of sample recordings is defined.
main_api.launch()