ujalaarshad17 commited on
Commit
0474f44
·
1 Parent(s): f26cacc

Added files

Browse files
Files changed (1) hide show
  1. app.py +9 -21
app.py CHANGED
@@ -3,10 +3,9 @@ import torch
3
  import torchaudio
4
  import numpy as np
5
  import plotly.graph_objs as go
6
- import os # For file operations
7
- from pydub import AudioSegment # For audio format conversion
8
- from model import BoundaryDetectionModel
9
- from audio_dataset import pad_audio
10
 
11
  app = Flask(__name__)
12
 
@@ -16,17 +15,7 @@ model = BoundaryDetectionModel().to(device)
16
  model.load_state_dict(torch.load("checkpoint_epoch_21_eer_0.24.pth", map_location=device)["model_state_dict"])
17
  model.eval()
18
 
19
- def convert_to_wav(audio_path, temp_path="temp_audio.wav"):
20
- # Check if the file is already in .wav format
21
- if audio_path.lower().endswith(".wav"):
22
- return audio_path
23
- # Convert to .wav using pydub if it's not already in .wav
24
- audio = AudioSegment.from_file(audio_path)
25
- audio.export(temp_path, format="wav")
26
- return temp_path
27
-
28
  def preprocess_audio(audio_path, sample_rate=16000, target_length=8):
29
- # Load the audio waveform
30
  waveform, sr = torchaudio.load(audio_path)
31
  if sr != sample_rate:
32
  waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
@@ -52,10 +41,8 @@ def predict():
52
  if file.filename == '':
53
  return "No selected file", 400
54
 
55
- # Save the file to a temporary location and convert if necessary
56
- original_path = "temp_uploaded_audio"
57
- file.save(original_path)
58
- file_path = convert_to_wav(original_path) # Convert to .wav if needed
59
 
60
  # Preprocess audio and perform inference
61
  audio_tensor = preprocess_audio(file_path)
@@ -96,8 +83,9 @@ def predict():
96
  def return_to_index():
97
  # Delete temporary files before returning to index
98
  try:
99
- os.remove("temp_uploaded_audio") # Remove original uploaded audio file
100
- os.remove("temp_audio.wav") # Remove the converted .wav file if necessary
 
101
  except OSError as e:
102
  print(f"Error deleting temporary files: {e}")
103
 
@@ -169,4 +157,4 @@ def plot_fake_frames_waveform(output, prediction_flat, waveform, fake_frame_inte
169
  return plot_html
170
 
171
  if __name__ == '__main__':
172
- app.run()
 
3
  import torchaudio
4
  import numpy as np
5
  import plotly.graph_objs as go
6
+ import os # Import os for file operations
7
+ from model import BoundaryDetectionModel # Assuming your model is defined here
8
+ from audio_dataset import pad_audio # Assuming you have a function to pad audio
 
9
 
10
  app = Flask(__name__)
11
 
 
15
  model.load_state_dict(torch.load("checkpoint_epoch_21_eer_0.24.pth", map_location=device)["model_state_dict"])
16
  model.eval()
17
 
 
 
 
 
 
 
 
 
 
18
  def preprocess_audio(audio_path, sample_rate=16000, target_length=8):
 
19
  waveform, sr = torchaudio.load(audio_path)
20
  if sr != sample_rate:
21
  waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
 
41
  if file.filename == '':
42
  return "No selected file", 400
43
 
44
+ file_path = "temp_audio.wav" # Temporary file to store uploaded audio
45
+ file.save(file_path)
 
 
46
 
47
  # Preprocess audio and perform inference
48
  audio_tensor = preprocess_audio(file_path)
 
83
  def return_to_index():
84
  # Delete temporary files before returning to index
85
  try:
86
+ os.remove("temp_audio.wav") # Remove the temporary audio file
87
+ # If you have any other temporary files (like plots), remove them here too.
88
+ # Example: os.remove("temp_plot.html") if you save plots as HTML files.
89
  except OSError as e:
90
  print(f"Error deleting temporary files: {e}")
91
 
 
157
  return plot_html
158
 
159
  if __name__ == '__main__':
160
+ app.run()