import os import glob import torchaudio import torchaudio.transforms as T import numpy as np from matplotlib import pyplot as plt import librosa import librosa.display from df import enhance, init_df import streamlit as st from streamlit.components.v1 import html app_title = "소음 억제 도구" model, df_state, _ = init_df() # Load default model df_sr = 48000 def display_audio_info(audio, title): # 두 개의 컬럼 생성 col1, col2 = st.columns(2) audio = np.clip(audio, -1.0, 1.0) if len(np.shape(audio)) == 2: audio = audio[0] # 왼쪽 컬럼에 스펙트로그램 표시 with col1: st.markdown(f"### {title} - Spectrogram") D = librosa.stft(audio) # STFT of y S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max) fig, ax = plt.subplots() img = librosa.display.specshow( S_db, x_axis='time', y_axis='linear', ax=ax) fig.colorbar(img, ax=ax, format="%+2.f dB") st.pyplot(fig) # 오른쪽 컬럼에 파형 표시 with col2: st.markdown(f"### {title} - Waveform") fig, ax = plt.subplots() plt.plot(audio) ax.set_xticks([]) ax.set_ylim(-1, 1) st.pyplot(fig) def main(): st.set_page_config(page_title=app_title, page_icon="favicon.ico", layout="centered", initial_sidebar_state="auto", menu_items=None) button = """""" st.title(app_title) st.divider() st.header('손쉽게 불필요한 소음을 제거하세요!') uploaded_file = st.file_uploader( "변환할 파일을 업로드 해주세요. (지원 형식: .wav, .mp3, .opus)") if uploaded_file: # 이전에 다운로드 한 파일을 삭제 files_to_remove = glob.glob('enhanced_*') for file in files_to_remove: os.remove(file) uploaded_file_type = uploaded_file.type.split('/')[-1] print(uploaded_file_type) if uploaded_file_type not in ['wav', 'mpeg', 'ogg']: st.text('지원하지 않는 파일 형식입니다.') else: with st.spinner('소음 제거하는 중'): noisy_audio, sr = torchaudio.load(uploaded_file) print("np.shape(noisy_audio)", np.shape(noisy_audio)) st.audio(noisy_audio.numpy(), sample_rate=sr) # 샘플링 레이트가 48000Hz가 아닐 경우 리샘플링 if sr != df_sr: resampler = T.Resample(orig_freq=sr, new_freq=df_sr) noisy_audio = resampler(noisy_audio) display_audio_info(noisy_audio.numpy(), "입력") with st.spinner('소음 제거하는 중'): output_audio = enhance(model, df_state, noisy_audio) enhanced_audio = output_audio st.divider() # 샘플링 레이트가 48000Hz가 아닐 경우 리샘플링 if sr != df_sr: resampler = T.Resample(orig_freq=df_sr, new_freq=sr) enhanced_audio = resampler(enhanced_audio) st.audio(enhanced_audio.numpy(), sample_rate=sr) display_audio_info(output_audio.numpy(), "출력") html(button, height=70, width=240) st.markdown( """ """, unsafe_allow_html=True, ) if __name__ == '__main__': main()