# Spaces: Sleeping
import os  # Import os for file operations
import tempfile

import numpy as np
import plotly.graph_objs as go
import torch
import torchaudio
from flask import Flask, request, render_template, redirect, url_for

from audio_dataset import pad_audio  # Assuming you have a function to pad audio
from model import BoundaryDetectionModel  # Assuming your model is defined here
app = Flask(__name__)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the network, move it to the chosen device, then restore the trained
# weights stored under the checkpoint's "model_state_dict" key.
model = BoundaryDetectionModel()
model = model.to(device)
model.load_state_dict(
    torch.load("checkpoint_epoch_21_eer_0.24.pth", map_location=device)["model_state_dict"]
)
# Switch to evaluation mode; this process only performs inference.
model.eval()
def preprocess_audio(audio_path, sample_rate=16000, target_length=8):
    """Load an audio file, resample it if needed, pad it, and move it to ``device``.

    Args:
        audio_path: Path handed to ``torchaudio.load``.
        sample_rate: Rate the waveform is resampled to when the file's own
            rate differs.
        target_length: Forwarded unchanged to ``pad_audio``
            (presumably a duration in seconds -- confirm against ``pad_audio``).

    Returns:
        The tensor returned by ``pad_audio``, placed on the module-level ``device``.
    """
    waveform, native_rate = torchaudio.load(audio_path)
    if native_rate != sample_rate:
        resampler = torchaudio.transforms.Resample(native_rate, sample_rate)
        waveform = resampler(waveform)
    padded = pad_audio(waveform, sample_rate, target_length)
    return padded.to(device)
def infer_single_audio(audio_tensor):
    """Run the module-level ``model`` on one audio tensor without tracking gradients.

    Args:
        audio_tensor: Input passed directly to ``model``.

    Returns:
        A ``(output, prediction)`` pair of NumPy arrays: ``output`` is the model
        result with its trailing dimension squeezed, and ``prediction`` is
        ``output`` thresholded at 0.5 and cast to ``int`` (1 where the score
        exceeds 0.5, otherwise 0).
    """
    with torch.no_grad():
        scores = model(audio_tensor)
        output = scores.squeeze(-1).cpu().numpy()
    # Binary prediction for fake/real frames
    prediction = (output > 0.5).astype(int)
    return output, prediction
# Registered at the site root: return_to_index() resolves url_for('index'),
# which requires this view to be a registered endpoint.
@app.route('/')
def index():
    """Render ``index.html``, the page for file upload and results display."""
    return render_template('index.html')
# NOTE(review): no route registration was visible for this view; the '/predict'
# path is an assumption -- confirm it matches the form action in index.html.
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded audio file and render the result page.

    Reads the upload from ``request.files['file']``, runs preprocessing and
    inference, counts the frames predicted as fake, and renders
    ``result.html`` with the statistics and a Plotly waveform plot.

    Returns:
        The rendered ``result.html`` page, or a ``(message, 400)`` tuple when
        no file was uploaded or no file was selected.
    """
    if 'file' not in request.files:
        return "No file uploaded", 400
    upload = request.files['file']
    if upload.filename == '':
        return "No selected file", 400

    # Give every request its own temporary file so that concurrent uploads
    # cannot overwrite each other, and always delete it once the waveform has
    # been loaded and scored. The ".wav" suffix mirrors the previous fixed
    # "temp_audio.wav" name.
    fd, file_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        upload.save(file_path)
        # Preprocess audio and perform inference
        audio_tensor = preprocess_audio(file_path)
        output, prediction = infer_single_audio(audio_tensor)
    finally:
        try:
            os.remove(file_path)
        except OSError as e:
            print(f"Error deleting temporary files: {e}")

    # Flatten the prediction array to handle 2D structure
    prediction_flat = prediction.flatten()

    # Total frames, fake frames, and fake percentage (rounded to 4 decimals).
    total_frames = len(prediction_flat)
    fake_frame_count = int(np.sum(prediction_flat))
    # Guard against an empty prediction to avoid a ZeroDivisionError.
    if total_frames:
        fake_percentage = round((fake_frame_count / total_frames) * 100, 4)
    else:
        fake_percentage = 0.0

    # The clip is labelled fake once at least this many frames are flagged.
    min_fake_frames = 5
    result_type = 'Fake' if fake_frame_count >= min_fake_frames else 'Real'

    if result_type == 'Real':
        # Real audio reports the placeholder string instead of a list.
        fake_frame_intervals = "No Frame"
    else:
        # Start/end times (seconds) of each consecutive run of fake frames.
        fake_frame_intervals = get_fake_frame_intervals(prediction_flat, frame_duration=20)

    # Debug print to check intervals
    print("Fake Frame Intervals:", fake_frame_intervals)

    # Generate Plotly plot
    plot_html = plot_fake_frames_waveform(
        output, prediction_flat, audio_tensor.cpu().numpy(), fake_frame_intervals
    )

    # Render template with all results and plot
    return render_template(
        'result.html',
        fake_percentage=fake_percentage,
        result_type=result_type,
        fake_frame_count=fake_frame_count,
        total_frames=total_frames,
        fake_frame_intervals=fake_frame_intervals,
        plot_html=plot_html,
    )
# NOTE(review): no route registration was visible for this view; the path below
# is an assumption -- confirm it matches the link/form used in result.html.
@app.route('/return_to_index')
def return_to_index():
    """Delete the temporary audio file, then redirect to the ``index`` view.

    Cleanup is best-effort: a file that is already gone is not an error, and
    any other ``OSError`` is reported but does not block the redirect.
    """
    try:
        os.remove("temp_audio.wav")  # Remove the temporary audio file
        # If you have any other temporary files (like plots), remove them here too.
        # Example: os.remove("temp_plot.html") if you save plots as HTML files.
    except FileNotFoundError:
        # Nothing to clean up (e.g. the page was opened without a prior upload).
        pass
    except OSError as e:
        print(f"Error deleting temporary files: {e}")
    return redirect(url_for('index'))  # Redirect back to the main page
def get_fake_frame_intervals(prediction, frame_duration=20):
    """Return ``(start, end)`` times in seconds for every run of fake frames.

    Args:
        prediction: Sequence of per-frame labels; a frame counts as fake when
            its label equals 1.
        frame_duration: Length of one frame in milliseconds.

    Returns:
        A list of ``(start_seconds, end_seconds)`` tuples, each rounded to four
        decimals. A run begins at the start of its first fake frame and ends
        at the start of the first non-fake frame that follows it, or at the
        end of the sequence when the run reaches the last frame.
    """
    frame_seconds = frame_duration / 1000  # Convert ms to seconds

    def to_interval(first_frame, stop_frame):
        return (round(first_frame * frame_seconds, 4), round(stop_frame * frame_seconds, 4))

    spans = []
    open_at = None  # index of the first frame of the run currently in progress
    for frame_idx, label in enumerate(prediction):
        flagged = label == 1
        if flagged and open_at is None:
            open_at = frame_idx
        elif not flagged and open_at is not None:
            spans.append(to_interval(open_at, frame_idx))
            open_at = None

    # Close a run that is still open when the sequence ends.
    if open_at is not None:
        spans.append(to_interval(open_at, len(prediction)))
    return spans
def plot_fake_frames_waveform(output, prediction_flat, waveform, fake_frame_intervals, frame_duration=20, sample_rate=16000):
    """Build an embeddable Plotly plot of the waveform coloured by frame label.

    Consecutive frames with the same label are drawn as one trace (red for
    frames whose label equals 1, green otherwise), so the number of traces
    grows with the number of label changes rather than the number of frames.

    Args:
        output: Accepted for interface compatibility; not used here.
        prediction_flat: 1-D sequence of per-frame labels.
        waveform: Array indexed as ``waveform[0][sample]``; ``waveform.shape[1]``
            is taken as the number of samples.
        fake_frame_intervals: Accepted for interface compatibility; not used here.
        frame_duration: Length of one frame in milliseconds.
        sample_rate: Samples per second, used for the time axis and frame size.

    Returns:
        The figure as an HTML fragment (``full_html=False``).
    """
    num_samples = waveform.shape[1]
    actual_duration = num_samples / sample_rate
    # Sample k lies at k / sample_rate seconds.
    time = np.arange(num_samples) / sample_rate
    frame_length = int(sample_rate * frame_duration / 1000)  # Samples per frame

    is_fake = np.asarray(prediction_flat).ravel() == 1
    traces = []
    if is_fake.size:
        # Frame indices at which the label differs from the previous frame.
        change_points = np.flatnonzero(is_fake[1:] != is_fake[:-1]) + 1
        run_starts = np.concatenate(([0], change_points))
        run_ends = np.concatenate((change_points, [is_fake.size]))
        for first_frame, stop_frame in zip(run_starts, run_ends):
            start = int(first_frame) * frame_length
            if start >= num_samples:
                # Remaining frames lie past the end of the audio; nothing to draw.
                break
            end = min(int(stop_frame) * frame_length, num_samples)
            color = 'rgba(255,0,0,0.8)' if is_fake[first_frame] else 'rgba(0,128,0,0.5)'
            traces.append(go.Scatter(
                x=time[start:end],
                y=waveform[0][start:end],
                mode='lines',
                line=dict(color=color),
                showlegend=False
            ))

    # Full waveform view to show all fake and real segments
    layout = go.Layout(
        title="Audio Waveform with Fake Frames Highlighted",
        xaxis=dict(title="Time (seconds)", range=[0, actual_duration]),
        yaxis=dict(title="Amplitude"),
        autosize=True,
        template="plotly_white"
    )
    fig = go.Figure(data=traces, layout=layout)
    # Convert Plotly figure to HTML
    return fig.to_html(full_html=False)
# Start Flask's built-in server only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    app.run()