# Emo-Fake-UI / app.py
# ggirishg's picture
# Update app.py
# ce8c430 verified
import os
import tempfile

import numpy as np
import streamlit as st
import torch
import torchaudio
from tensorflow.keras.models import load_model
from torchaudio.transforms import Resample
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
# Load models
#
# Streamlit re-executes this script from top to bottom on every user
# interaction. The loaders below are wrapped in ``st.cache_resource`` so the
# large wav2vec checkpoint and the Keras classifier are created once per
# server process instead of once per rerun.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
saved_model_path = "model.h5"


# show_spinner=False: the loaders must not render any element, because
# st.set_page_config() is called later in the script.
@st.cache_resource(show_spinner=False)
def _load_wav2vec():
    """Return the ``facebook/mms-1b`` feature extractor and encoder.

    The encoder is moved to the module-level ``device`` before it is
    returned. The result is cached by Streamlit, so repeated calls hand back
    the same objects.
    """
    extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/mms-1b")
    encoder = Wav2Vec2Model.from_pretrained("facebook/mms-1b").to(device)
    return extractor, encoder


@st.cache_resource(show_spinner=False)
def _load_cnn(model_path):
    """Return the Keras model stored at ``model_path`` (cached per path)."""
    return load_model(model_path)


feature_extractor, wav2vec_model = _load_wav2vec()

try:
    cnn_model = _load_cnn(saved_model_path)
except Exception as e:
    # Without the classifier the app cannot do anything useful: report the
    # problem in the UI and halt this script run.
    st.error(f"Error loading TensorFlow model: {e}")
    st.stop()
# Preprocessing Function
def preprocess_audio(audio_path):
    """Load an audio file and bring it to a 16 kHz, single-channel waveform.

    Args:
        audio_path: Path of the audio file, handed to ``torchaudio.load``.

    Returns:
        A ``(waveform, sampling_rate)`` tuple where ``sampling_rate`` is
        always 16000. If loading or conversion raises, the error is shown
        with ``st.error`` and ``(None, None)`` is returned instead.
    """
    target_rate = 16000
    try:
        signal, source_rate = torchaudio.load(audio_path)

        # Resample only when the file is not already at the target rate.
        if source_rate != target_rate:
            signal = Resample(source_rate, target_rate)(signal)

        # More than one row along the first axis: average the rows into one,
        # keeping that axis so the result stays two-dimensional.
        row_count = signal.shape[0]
        if row_count > 1:
            signal = signal.mean(dim=0, keepdim=True)
    except Exception as e:
        st.error(f"Error processing audio file: {e}")
        return None, None

    return signal, target_rate
# Feature Extraction
def extract_features(audio_path, feature_extractor, wav2vec_model, device):
    """Compute one time-averaged wav2vec embedding for an audio file.

    Args:
        audio_path: Path of the audio file; passed to ``preprocess_audio``.
        feature_extractor: Callable that turns raw samples into model inputs
            (called with ``sampling_rate=`` and ``return_tensors="pt"``).
        wav2vec_model: Model called with those inputs; its output must expose
            ``last_hidden_state``.
        device: Device the model inputs are moved to.

    Returns:
        A NumPy array holding the mean of the hidden states over the time
        axis, or ``None`` when ``preprocess_audio`` could not load the file.
    """
    waveform, fs = preprocess_audio(audio_path)
    if waveform is None:
        return None

    # Remove only the leading channel axis. A bare squeeze() would also
    # collapse the sample axis if it happened to have length 1.
    samples = waveform.squeeze(0).numpy()
    inputs = feature_extractor(samples, sampling_rate=fs, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = wav2vec_model(**inputs)

    embeddings = outputs.last_hidden_state.cpu().numpy()
    # Select the single batch item explicitly instead of squeeze(): for a very
    # short clip that yields one frame, squeeze() would drop the time axis as
    # well and the mean below would run over the hidden dimension, returning
    # a scalar. Assumes last_hidden_state is (batch, frames, hidden), as in
    # the transformers documentation.
    avg_embeddings = np.mean(embeddings[0], axis=0)
    return avg_embeddings
# Prediction
def predict_with_cnn(audio_path, cnn_model, feature_extractor, wav2vec_model, device):
    """Classify an audio file as ``"bonafide"`` or ``"spoof"``.

    Args:
        audio_path: Path of the audio file to classify.
        cnn_model: Model whose ``predict`` method scores the embedding.
        feature_extractor: Forwarded to ``extract_features``.
        wav2vec_model: Forwarded to ``extract_features``.
        device: Forwarded to ``extract_features``.

    Returns:
        ``(label, scores, confidence)`` where ``scores`` is the first row of
        the model output and ``confidence`` is the entry of that row belonging
        to ``label``. ``(None, None, None)`` if no features could be
        extracted.
    """
    embedding = extract_features(audio_path, feature_extractor, wav2vec_model, device)
    if embedding is None:
        return None, None, None

    # Add a leading and a trailing singleton axis in one call.
    model_input = np.expand_dims(embedding, axis=(0, 2))
    scores = cnn_model.predict(model_input)

    labels = ["bonafide", "spoof"]
    winner = np.argmax(scores, axis=1)[0]  # index of the highest score in row 0
    first_row = scores[0]
    return labels[winner], first_row, first_row[winner]
# Streamlit Application
#
# Page flow: configure the page, ask for a WAV upload, play it back, run the
# classifier on it and render the predicted label with per-class scores.
st.set_page_config(page_title="🎵 Audio Spoof Detection", layout="wide")
st.title("🎵 Audio Spoof Detection")
st.markdown(
    """
This application uses advanced machine learning models to detect whether an audio file is **bonafide** (real) or **spoofed** (fake).
Upload a `.wav` file to get started!
"""
)

# File Upload
uploaded_file = st.file_uploader(
    "Upload your audio file (WAV format only):", type=["wav"]
)

if uploaded_file:
    audio_bytes = uploaded_file.getvalue()

    # Display Audio Player (straight from memory, no file on disk needed)
    st.audio(audio_bytes, format="audio/wav")

    # Processing Audio
    st.write("🎧 **Processing the audio...**")

    # The prediction pipeline loads audio from a path, so the upload has to
    # be written to disk. A uniquely named temporary file keeps concurrent
    # sessions from overwriting each other's audio (a fixed file name in the
    # working directory would be shared by all of them). delete=False plus an
    # explicit close lets the file be reopened by path on every platform.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp.write(audio_bytes)
        temp_file_path = tmp.name
    try:
        predicted_class, probabilities, confidence = predict_with_cnn(
            temp_file_path, cnn_model, feature_extractor, wav2vec_model, device
        )
    finally:
        # Always remove the temporary copy, even if prediction raised.
        os.remove(temp_file_path)

    # Display Results
    if predicted_class:
        col1, col2 = st.columns(2)
        with col1:
            st.markdown(
                f"""
## 🎉 **Prediction: `{predicted_class.upper()}`**
"""
            )
            st.markdown(f"### **Confidence**: `{confidence:.2f}`")
        with col2:
            st.write("### Class Probabilities")
            st.bar_chart(probabilities)

        # Display Detailed Probabilities
        st.markdown("### Class Details")
        st.write(f"**Bonafide Probability**: `{probabilities[0]:.2f}`")
        st.write(f"**Spoof Probability**: `{probabilities[1]:.2f}`")
    else:
        st.error("Failed to process the audio file. Please try again.")
else:
    st.info("Please upload a `.wav` audio file to analyze.")