# AudioSpoofing / app.py
# Last commit: 0474f44 "Added files" (by ujalaarshad17)
import os # Import os for file operations
import tempfile

import numpy as np
import plotly.graph_objs as go
import torch
import torchaudio
from flask import Flask, request, render_template, redirect, url_for

from audio_dataset import pad_audio # Assuming you have a function to pad audio
from model import BoundaryDetectionModel # Assuming your model is defined here

app = Flask(__name__)

# Build the boundary-detection model once, at import time, so every request
# reuses the same weights. CUDA is used when available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = BoundaryDetectionModel()
model = model.to(device)
# The checkpoint is a dict whose "model_state_dict" entry holds the weights;
# map_location lets a GPU-saved checkpoint load on a CPU-only host.
model.load_state_dict(
    torch.load("checkpoint_epoch_21_eer_0.24.pth", map_location=device)["model_state_dict"]
)
# Evaluation mode (affects layers such as dropout / batch-norm, if present).
model.eval()
def preprocess_audio(audio_path, sample_rate=16000, target_length=8):
    """Load an audio file and turn it into a model-ready tensor.

    Steps: decode with ``torchaudio.load``, downmix multi-channel audio to a
    single channel, resample to ``sample_rate`` when the file's rate differs,
    pad with ``pad_audio`` and move the result to the module-level ``device``.

    Args:
        audio_path: Path of the audio file to decode.
        sample_rate: Target sample rate in Hz.
        target_length: Length argument forwarded to ``pad_audio``
            (presumably seconds -- confirm against ``audio_dataset.pad_audio``).

    Returns:
        The padded waveform tensor on ``device``, with a single channel row.
    """
    waveform, sr = torchaudio.load(audio_path)  # (channels, time)
    # Multi-channel uploads would otherwise reach the model as several rows,
    # which predict() flattens into one long frame sequence (e.g. doubling the
    # frame count for stereo and misaligning frames with the plotted channel).
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != sample_rate:
        waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
    waveform = pad_audio(waveform, sample_rate, target_length)
    return waveform.to(device)
def infer_single_audio(audio_tensor, threshold=0.5, net=None):
    """Run frame-level spoof detection on one preprocessed audio tensor.

    Args:
        audio_tensor: Model input, e.g. as returned by ``preprocess_audio``
            (already on the model's device).
        threshold: Scores strictly above this value are labelled fake (1).
            Defaults to 0.5, the value previously hard-coded here.
        net: Optional callable used instead of the module-level ``model``
            (useful for tests or alternative checkpoints).

    Returns:
        ``(output, prediction)``:
        - ``output``: the network's scores as a NumPy array, with a trailing
          size-1 dimension removed if present (presumably per-frame
          probabilities -- not verifiable from this file).
        - ``prediction``: an int array of the same shape, 1 where
          ``output > threshold`` and 0 elsewhere.
    """
    # Resolve lazily so the global model is only touched when no override is given.
    net = model if net is None else net
    with torch.no_grad():
        output = net(audio_tensor).squeeze(-1).cpu().numpy()
    prediction = (output > threshold).astype(int)  # binary fake/real per frame
    return output, prediction
@app.route("/")
def index():
    """Serve the landing page, which hosts the upload form and result display."""
    template_name = "index.html"
    return render_template(template_name)
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded audio clip as real or fake and render the results.

    Expects a multipart upload under the form field ``file``. The upload is
    written to a uniquely named temporary ``.wav`` file, preprocessed, and the
    temporary file is deleted again before inference runs, so concurrent
    requests never share or clobber each other's files.

    Returns:
        The rendered ``result.html`` page, or ``(message, 400)`` when no file
        was provided.
    """
    if 'file' not in request.files:
        return "No file uploaded", 400
    upload = request.files['file']
    if upload.filename == '':
        return "No selected file", 400

    # Duration of one prediction frame in ms, passed to both the interval and
    # plot helpers so they agree (assumed to match the model's frame hop).
    frame_duration_ms = 20
    # The whole clip is labelled 'Fake' once this many frames are predicted fake.
    min_fake_frames = 5

    audio_tensor = _load_uploaded_audio(upload)
    output, prediction = infer_single_audio(audio_tensor)

    # Flatten so a 2-D (e.g. batch x frames) prediction becomes one frame sequence.
    prediction_flat = prediction.flatten()
    total_frames = len(prediction_flat)
    fake_frame_count = int(np.sum(prediction_flat))
    # Fake percentage rounded to 4 decimals; a zero-frame output yields 0.0
    # instead of raising ZeroDivisionError.
    if total_frames:
        fake_percentage = round((fake_frame_count / total_frames) * 100, 4)
    else:
        fake_percentage = 0.0
    result_type = 'Fake' if fake_frame_count >= min_fake_frames else 'Real'

    if result_type == 'Real':
        fake_frame_intervals = "No Frame"  # placeholder shown by the template
    else:
        # (start_s, end_s) tuples for every run of consecutive fake frames.
        fake_frame_intervals = get_fake_frame_intervals(
            prediction_flat, frame_duration=frame_duration_ms
        )
    app.logger.debug("Fake Frame Intervals: %s", fake_frame_intervals)

    plot_html = plot_fake_frames_waveform(
        output,
        prediction_flat,
        audio_tensor.cpu().numpy(),
        fake_frame_intervals,
        frame_duration=frame_duration_ms,
    )
    return render_template('result.html',
                           fake_percentage=fake_percentage,
                           result_type=result_type,
                           fake_frame_count=fake_frame_count,
                           total_frames=total_frames,
                           fake_frame_intervals=fake_frame_intervals,
                           plot_html=plot_html)


def _load_uploaded_audio(upload):
    """Save *upload* to a private temp file, preprocess it, then delete the file."""
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # upload.save() reopens the path itself (also needed on Windows)
    try:
        upload.save(tmp_path)
        return preprocess_audio(tmp_path)
    finally:
        try:
            os.remove(tmp_path)
        except OSError as err:
            # Best-effort cleanup: never let a failed delete mask the result.
            app.logger.warning("Could not delete temporary upload %s: %s", tmp_path, err)
@app.route('/return', methods=['GET'])
def return_to_index():
    """Remove the shared upload file, if any, then redirect to the upload page.

    Deletes ``temp_audio.wav`` from the working directory. The file being
    absent is the normal state when nothing was uploaded since the last
    cleanup, so it is not reported. Any other filesystem error is printed and
    the redirect still happens (best-effort cleanup).
    """
    try:
        os.remove("temp_audio.wav")
    except FileNotFoundError:
        pass  # nothing to clean up
    except OSError as e:
        print(f"Error deleting temporary files: {e}")
    return redirect(url_for('index'))
def get_fake_frame_intervals(prediction, frame_duration=20):
    """Return (start, end) times in seconds for each run of consecutive fake frames.

    Args:
        prediction: Iterable of per-frame labels; a frame counts as fake when
            its label equals 1.
        frame_duration: Length of one frame in milliseconds.

    Returns:
        List of ``(start_s, end_s)`` tuples, each rounded to 4 decimals. ``end_s``
        is the start time of the first non-fake frame after the run, or the
        total duration when the run reaches the last frame.
    """
    seconds_per_frame = frame_duration / 1000
    flags = [label == 1 for label in prediction]
    intervals = []
    run_start = None
    # A trailing False sentinel closes a run that extends to the final frame;
    # its index equals the number of frames, i.e. the clip's end.
    for idx, flag in enumerate(flags + [False]):
        if flag and run_start is None:
            run_start = idx
        elif not flag and run_start is not None:
            intervals.append((
                round(run_start * seconds_per_frame, 4),
                round(idx * seconds_per_frame, 4),
            ))
            run_start = None
    return intervals
def plot_fake_frames_waveform(output, prediction_flat, waveform, fake_frame_intervals, frame_duration=20, sample_rate=16000):
    """Render the waveform as Plotly HTML, colouring each frame by its prediction.

    Frame ``i`` covers samples ``[i * frame_length, (i + 1) * frame_length)``,
    where ``frame_length = sample_rate * frame_duration / 1000``. Samples of
    frames labelled fake (1) are drawn red, all others green.

    Args:
        output: Raw model scores. Unused; kept for API compatibility.
        prediction_flat: 1-D sequence of per-frame labels (1 = fake).
        waveform: Array-like audio indexed as (channel, sample); a 1-D array is
            treated as a single channel. Only the first channel is plotted.
        fake_frame_intervals: Unused; kept for API compatibility.
        frame_duration: Frame length in milliseconds.
        sample_rate: Sample rate of ``waveform`` in Hz.

    Returns:
        An HTML fragment (``full_html=False``) containing the figure.
    """
    waveform = np.atleast_2d(np.asarray(waveform))
    num_samples = waveform.shape[1]
    actual_duration = num_samples / sample_rate
    # Sample k sits at k / sample_rate seconds, so frame boundaries on the plot
    # coincide with i * frame_duration / 1000. np.linspace(0, dur, n) would
    # space samples at dur / (n - 1) and drift from the reported intervals.
    time = np.arange(num_samples) / sample_rate
    channel = waveform[0]

    frame_length = int(sample_rate * frame_duration / 1000)  # samples per frame
    traces = []
    for i, label in enumerate(prediction_flat):
        start = i * frame_length
        if start >= num_samples:
            break  # this and all later frames lie past the audio: nothing to draw
        end = min(start + frame_length, num_samples)
        color = 'rgba(255,0,0,0.8)' if label == 1 else 'rgba(0,128,0,0.5)'
        traces.append(go.Scatter(
            x=time[start:end],
            y=channel[start:end],
            mode='lines',
            line=dict(color=color),
            showlegend=False
        ))

    # Show the full clip so every fake and real segment is visible.
    layout = go.Layout(
        title="Audio Waveform with Fake Frames Highlighted",
        xaxis=dict(title="Time (seconds)", range=[0, actual_duration]),
        yaxis=dict(title="Amplitude"),
        autosize=True,
        template="plotly_white"
    )
    fig = go.Figure(data=traces, layout=layout)
    return fig.to_html(full_html=False)
if __name__ == "__main__":
    # Script entry point: start Flask's built-in server with default settings.
    app.run()