from flask import Flask, request, render_template, redirect, url_for
import torch
import torchaudio
import numpy as np
import plotly.graph_objs as go
import os  # Import os for file operations
from model import BoundaryDetectionModel  # Assuming your model is defined here
from audio_dataset import pad_audio  # Assuming you have a function to pad audio

app = Flask(__name__)

# Load the pre-trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BoundaryDetectionModel().to(device)
model.load_state_dict(torch.load("checkpoint_epoch_21_eer_0.24.pth", map_location=device)["model_state_dict"])
model.eval()


def preprocess_audio(audio_path, sample_rate=16000, target_length=8):
    waveform, sr = torchaudio.load(audio_path)
    if sr != sample_rate:
        waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
    waveform = pad_audio(waveform, sample_rate, target_length)
    return waveform.to(device)


def infer_single_audio(audio_tensor):
    with torch.no_grad():
        output = model(audio_tensor).squeeze(-1).cpu().numpy()
        prediction = (output > 0.5).astype(int)  # Binary prediction for fake/real frames
    return output, prediction


@app.route('/')
def index():
    return render_template('index.html')  # HTML page for file upload and results display


@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return "No file uploaded", 400
    file = request.files['file']
    if file.filename == '':
        return "No selected file", 400

    file_path = "temp_audio.wav"  # Temporary file to store uploaded audio
    file.save(file_path)

    # Preprocess audio and perform inference
    audio_tensor = preprocess_audio(file_path)
    output, prediction = infer_single_audio(audio_tensor)

    # Flatten the prediction array to handle 2D structure
    prediction_flat = prediction.flatten()

    # Calculate total frames, fake frames, and fake percentage (rounded to 4 decimal places)
    total_frames = len(prediction_flat)
    fake_frame_count = int(np.sum(prediction_flat))
    fake_percentage = round((fake_frame_count / total_frames) * 100, 4)
    result_type = 'Fake' if fake_frame_count >= 5 else 'Real'

    # Check if audio is classified as real
    if result_type == 'Real':
        fake_frame_intervals = "No Frame"  # Set to "No Frame" if audio is real
    else:
        # Get precise fake frame timings with start and end times for fake frames
        fake_frame_intervals = get_fake_frame_intervals(prediction_flat, frame_duration=20)

    # Debug print to check intervals
    print("Fake Frame Intervals:", fake_frame_intervals)

    # Generate Plotly plot
    plot_html = plot_fake_frames_waveform(output, prediction_flat, audio_tensor.cpu().numpy(), fake_frame_intervals)

    # Render template with all results and plot
    return render_template('result.html',
                           fake_percentage=fake_percentage,
                           result_type=result_type,
                           fake_frame_count=fake_frame_count,
                           total_frames=total_frames,
                           fake_frame_intervals=fake_frame_intervals,
                           plot_html=plot_html)


@app.route('/return', methods=['GET'])
def return_to_index():
    # Delete temporary files before returning to index
    try:
        os.remove("temp_audio.wav")  # Remove the temporary audio file
        # If you have any other temporary files (like plots), remove them here too.
        # Example: os.remove("temp_plot.html") if you save plots as HTML files.
    except OSError as e:
        print(f"Error deleting temporary files: {e}")
    return redirect(url_for('index'))  # Redirect back to the main page


def get_fake_frame_intervals(prediction, frame_duration=20):
    """
    Calculate start and end times in seconds for each consecutive fake frame interval.
    """
    intervals = []
    start_time = None
    for i, is_fake in enumerate(prediction):
        if is_fake == 1:
            if start_time is None:
                start_time = i * (frame_duration / 1000)  # Convert ms to seconds
        else:
            if start_time is not None:
                end_time = i * (frame_duration / 1000)  # End time of fake segment
                intervals.append((round(start_time, 4), round(end_time, 4)))
                start_time = None
    # Append the last interval if the final fake segment runs to the last frame
    if start_time is not None:
        end_time = len(prediction) * (frame_duration / 1000)  # Final end time calculation
        intervals.append((round(start_time, 4), round(end_time, 4)))
    return intervals


def plot_fake_frames_waveform(output, prediction_flat, waveform, fake_frame_intervals, frame_duration=20, sample_rate=16000):
    # Get actual audio duration from waveform for accurate x-axis scaling
    actual_duration = waveform.shape[1] / sample_rate
    num_samples = waveform.shape[1]  # Get number of samples from the actual waveform
    time = np.linspace(0, actual_duration, num_samples)

    # Plotly traces for the waveform with different colors for fake and real frames
    frame_length = int(sample_rate * frame_duration / 1000)  # Samples per frame
    traces = []
    for i in range(len(prediction_flat)):
        start = i * frame_length
        end = min(start + frame_length, num_samples)  # Ensure we do not exceed the samples
        color = 'rgba(255,0,0,0.8)' if prediction_flat[i] == 1 else 'rgba(0,128,0,0.5)'
        traces.append(go.Scatter(
            x=time[start:end],
            y=waveform[0][start:end],
            mode='lines',
            line=dict(color=color),
            showlegend=False
        ))

    # Full waveform view to show all fake and real segments
    min_time, max_time = 0, actual_duration

    # Layout settings for the plot
    layout = go.Layout(
        title="Audio Waveform with Fake Frames Highlighted",
        xaxis=dict(title="Time (seconds)", range=[min_time, max_time]),
        yaxis=dict(title="Amplitude"),
        autosize=True,
        template="plotly_white"
    )

    fig = go.Figure(data=traces, layout=layout)

    # Convert Plotly figure to HTML
    plot_html = fig.to_html(full_html=False)
    return plot_html


if __name__ == '__main__':
    app.run()