File size: 4,389 Bytes
43d32f6
 
 
 
 
 
 
ce8c430
43d32f6
 
 
 
 
ce8c430
 
 
 
 
 
 
 
 
 
 
 
 
 
43d32f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import streamlit as st
import numpy as np
import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
import torchaudio
from torchaudio.transforms import Resample
from tensorflow.keras.models import load_model

# --- Model loading -----------------------------------------------------------
# Streamlit re-runs this script on every user interaction, so these heavy
# loads execute each time. NOTE(review): wrapping them in st.cache_resource
# would avoid repeated downloads/loads — confirm against deployment setup.

# Prefer GPU inference when CUDA is available; fall back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained wav2vec2 (MMS-1B) used as a frozen feature extractor; its
# hidden states feed the downstream CNN classifier.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/mms-1b")
wav2vec_model = Wav2Vec2Model.from_pretrained("facebook/mms-1b").to(device)

# Keras CNN classifier trained on wav2vec embeddings. If the weights file is
# missing or unreadable, surface the error in the UI and halt the app.
saved_model_path = "model.h5"
try:
    cnn_model = load_model(saved_model_path)
except Exception as e:
    st.error(f"Error loading TensorFlow model: {e}")
    st.stop()

# Preprocessing Function
def preprocess_audio(audio_path):
    """Load an audio file, resample it to 16 kHz, and collapse it to mono.

    Returns a ``(waveform, sample_rate)`` pair on success. On any failure
    the error is shown in the Streamlit UI and ``(None, None)`` is returned
    so callers can bail out gracefully.
    """
    target_rate = 16000
    try:
        waveform, native_rate = torchaudio.load(audio_path)
        # Bring the clip to the rate the wav2vec extractor expects.
        if native_rate != target_rate:
            waveform = Resample(native_rate, target_rate)(waveform)
        # Multi-channel audio -> single channel via per-sample mean.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        return waveform, target_rate
    except Exception as exc:
        st.error(f"Error processing audio file: {exc}")
        return None, None

# Feature Extraction
def extract_features(audio_path, feature_extractor, wav2vec_model, device):
    """Compute a fixed-size embedding for one audio file.

    Preprocesses the clip, runs it through the (frozen) wav2vec model, and
    mean-pools the last hidden states over the time axis. Returns a 1-D
    numpy array, or ``None`` when preprocessing failed.
    """
    waveform, sample_rate = preprocess_audio(audio_path)
    if waveform is None:
        return None
    model_inputs = feature_extractor(
        waveform.squeeze().numpy(), sampling_rate=sample_rate, return_tensors="pt"
    ).to(device)
    # Inference only — no gradients needed.
    with torch.no_grad():
        hidden_states = wav2vec_model(**model_inputs).last_hidden_state
    frame_embeddings = hidden_states.cpu().numpy().squeeze()
    # Average over time so every clip yields one vector regardless of length.
    return np.mean(frame_embeddings, axis=0)

# Prediction
def predict_with_cnn(audio_path, cnn_model, feature_extractor, wav2vec_model, device):
    """Classify an audio file as bonafide (real) or spoof (fake).

    Returns ``(label, class_probabilities, confidence)``. All three are
    ``None`` when feature extraction failed upstream.
    """
    class_names = ["bonafide", "spoof"]
    features = extract_features(audio_path, feature_extractor, wav2vec_model, device)
    if features is None:
        return None, None, None
    # Shape (1, dim, 1): a batch of one with a trailing channel axis,
    # matching the CNN's expected input layout.
    batch = features[np.newaxis, :, np.newaxis]
    probs = cnn_model.predict(batch)[0]
    winner = int(np.argmax(probs))
    # Confidence is the probability the model assigned to the chosen class.
    return class_names[winner], probs, probs[winner]

# Streamlit Application
import os
import tempfile

st.set_page_config(page_title="🎡 Audio Spoof Detection", layout="wide")
st.title("🎡 Audio Spoof Detection")
st.markdown(
    """
    This application uses advanced machine learning models to detect whether an audio file is **bonafide** (real) or **spoofed** (fake).
    Upload a `.wav` file to get started!
    """
)

# File Upload
uploaded_file = st.file_uploader(
    "Upload your audio file (WAV format only):", type=["wav"]
)

if uploaded_file:
    # Persist the upload to a unique temp file so torchaudio can open it by
    # path. A fixed filename ("temp_audio.wav") would collide when several
    # sessions run concurrently, and was never deleted afterwards.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp.write(uploaded_file.getbuffer())
        temp_file_path = tmp.name

    try:
        # Playback widget so the user can listen to what they uploaded.
        st.audio(temp_file_path, format="audio/wav")

        # Processing Audio
        st.write("🎧 **Processing the audio...**")
        predicted_class, probabilities, confidence = predict_with_cnn(
            temp_file_path, cnn_model, feature_extractor, wav2vec_model, device
        )
    finally:
        # Always clean up the temp file, even if inference raised.
        os.remove(temp_file_path)

    # Display Results — explicit None check: predict_with_cnn returns
    # (None, None, None) when feature extraction failed.
    if predicted_class is not None:
        col1, col2 = st.columns(2)

        with col1:
            st.markdown(
                f"""
                ## πŸŽ‰ **Prediction: `{predicted_class.upper()}`**
                """
            )
            st.markdown(f"### **Confidence**: `{confidence:.2f}`")

        with col2:
            st.write("### Class Probabilities")
            st.bar_chart(probabilities)

        # Display Detailed Probabilities
        st.markdown("### Class Details")
        st.write(f"**Bonafide Probability**: `{probabilities[0]:.2f}`")
        st.write(f"**Spoof Probability**: `{probabilities[1]:.2f}`")

    else:
        st.error("Failed to process the audio file. Please try again.")
else:
    st.info("Please upload a `.wav` audio file to analyze.")