Merge branch 'main' of https://huggingface.co/spaces/anhnv125/FRN
Browse files
- README.md +4 -2
- app.py +11 -7
- requirements.txt +1 -1
README.md
CHANGED
@@ -3,8 +3,10 @@ title: FRN
|
|
3 |
emoji: π
|
4 |
colorFrom: gray
|
5 |
colorTo: red
|
6 |
-
sdk:
|
7 |
-
pinned:
|
|
|
|
|
8 |
---
|
9 |
|
10 |
# FRN - Full-band Recurrent Network Official Implementation
|
|
|
3 |
emoji: π
|
4 |
colorFrom: gray
|
5 |
colorTo: red
|
6 |
+
sdk: streamlit
|
7 |
+
pinned: true
|
8 |
+
app_file: app.py
|
9 |
+
sdk_version: 1.10.0
|
10 |
---
|
11 |
|
12 |
# FRN - Full-band Recurrent Network Official Implementation
|
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import streamlit as st
|
2 |
import librosa
|
|
|
3 |
import librosa.display
|
4 |
from config import CONFIG
|
5 |
import torch
|
@@ -9,7 +10,7 @@ import matplotlib.pyplot as plt
|
|
9 |
import numpy as np
|
10 |
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
11 |
|
12 |
-
@st.
|
13 |
def load_model():
|
14 |
path = 'lightning_logs/version_0/checkpoints/frn.onnx'
|
15 |
onnx_model = onnx.load(path)
|
@@ -87,9 +88,9 @@ target = target[:packet_size * (len(target) // packet_size)]
|
|
87 |
st.subheader('Original audio')
|
88 |
st.audio(uploaded_file)
|
89 |
|
90 |
-
st.subheader('Choose
|
91 |
-
|
92 |
-
loss_percent = float(
|
93 |
mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])
|
94 |
lossy_input = target.copy().reshape(-1, packet_size)
|
95 |
mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
|
@@ -109,9 +110,12 @@ if st.button('Conceal lossy audio!'):
|
|
109 |
fig = visualize(target, lossy_input, output)
|
110 |
st.pyplot(fig)
|
111 |
st.success('Done!')
|
|
|
|
|
|
|
112 |
st.text('Original audio')
|
113 |
-
st.audio(target
|
114 |
st.text('Lossy audio')
|
115 |
-
st.audio(
|
116 |
st.text('Enhanced audio')
|
117 |
-
st.audio(
|
|
|
1 |
import streamlit as st
|
2 |
import librosa
|
3 |
+
import soundfile as sf
|
4 |
import librosa.display
|
5 |
from config import CONFIG
|
6 |
import torch
|
|
|
10 |
import numpy as np
|
11 |
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
12 |
|
13 |
+
@st.cache
|
14 |
def load_model():
|
15 |
path = 'lightning_logs/version_0/checkpoints/frn.onnx'
|
16 |
onnx_model = onnx.load(path)
|
|
|
88 |
st.subheader('Original audio')
|
89 |
st.audio(uploaded_file)
|
90 |
|
91 |
+
st.subheader('Choose expected packet loss rate')
|
92 |
+
slider = [st.slider("Expected loss rate for Markov Chain loss generator", 0, 100, step=1)]
|
93 |
+
loss_percent = float(slider[0])/100
|
94 |
mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])
|
95 |
lossy_input = target.copy().reshape(-1, packet_size)
|
96 |
mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
|
|
|
110 |
fig = visualize(target, lossy_input, output)
|
111 |
st.pyplot(fig)
|
112 |
st.success('Done!')
|
113 |
+
sf.write('target.wav', target, sr)
|
114 |
+
sf.write('lossy.wav', lossy_input, sr)
|
115 |
+
sf.write('enhanced.wav', output, sr)
|
116 |
st.text('Original audio')
|
117 |
+
st.audio('target.wav')
|
118 |
st.text('Lossy audio')
|
119 |
+
st.audio('lossy.wav')
|
120 |
st.text('Enhanced audio')
|
121 |
+
st.audio('enhanced.wav')
|
requirements.txt
CHANGED
@@ -12,6 +12,6 @@ soundfile==0.11.0
|
|
12 |
torch==1.13.1
|
13 |
torchmetrics==0.11.0
|
14 |
tqdm==4.64.0
|
15 |
-
|
16 |
pesq==0.0.4
|
17 |
onnx==1.13.0
|
|
|
12 |
torch==1.13.1
|
13 |
torchmetrics==0.11.0
|
14 |
tqdm==4.64.0
|
15 |
+
pystoi==0.3.3
|
16 |
pesq==0.0.4
|
17 |
onnx==1.13.0
|