Arnab Das commited on
Commit
5fde11f
·
1 Parent(s): 8773f6c
Files changed (2) hide show
  1. app_backup.py +0 -129
  2. models.py +2 -1
app_backup.py DELETED
@@ -1,129 +0,0 @@
1
- import torch
2
- import gradio as gr
3
- import models as MOD
4
- import process_data as PD
5
- from transformers import pipeline
6
- from manipulate_model.utils import get_config_and_model, infere
7
-
8
- model_master = {
9
- "SSL-AASIST (Trained on ASV-Spoof5)": {"eer_threshold": 3.3330237865448,
10
- "data_process_func": "process_ssl_assist_input",
11
- "note": "This model is trained only on ASVSpoof 2024 training data.",
12
- "model_class": "Model",
13
- "model_checkpoint": "ssl_aasist_epoch_7.pth"},
14
- "AASIST": {"eer_threshold": 1.8018419742584229,
15
- "data_process_func": "process_assist_input",
16
- "note": "This model is trained on ASVSpoof 2024 training data.",
17
- "model_class":"AASIST_Model",
18
- "model_checkpoint": "orig_aasist_epoch_1.pth"}
19
- }
20
-
21
- model = MOD.Model(None, "cpu")
22
- model.load_state_dict(torch.load("ssl_aasist_epoch_7.pth", map_location="cpu"))
23
- model.eval()
24
- loaded_model = "SSL-AASIST (Trained on ASV-Spoof5)"
25
-
26
- manpulate_config, manipulate_model = get_config_and_model()
27
-
28
- def process(file, type):
29
- global model
30
- global loaded_model
31
- inp = getattr(PD, model_master[type]["data_process_func"])(file)
32
- if not loaded_model == type:
33
- model = getattr(MOD, model_master[type]["model_class"])(None, "cpu")
34
- model.load_state_dict(torch.load(model_master[type]["model_checkpoint"], map_location="cpu"))
35
- model.eval()
36
- loaded_model = type
37
-
38
- op = model(inp).detach().squeeze()[1].item()
39
-
40
- response_text = "Decision score: {} \nDecision threshold: {} \nNotes: 1. Any score below threshold is indicative of fake. \n2. {} ".format(
41
- str(op), str(model_master[type]["eer_threshold"]), model_master[type]["note"])
42
- return response_text
43
-
44
-
45
- demo = gr.Blocks()
46
- file_proc = gr.Interface(
47
- fn=process,
48
- inputs=[
49
- gr.Audio(sources=["upload"], label="Audio file", type="filepath"),
50
- gr.Radio(["SSL-AASIST (Trained on ASV-Spoof5)", "AASIST"], label="Select Model", type="value"),
51
- ],
52
- outputs="text",
53
- title="Find the Fake: Analyze 'Real' or 'Fake'.",
54
- description=(
55
- "Analyze fake or real with a click of a button. Upload a .wav or .flac file."
56
- ),
57
- examples=[
58
- ["./bonafide.flac", "SSL-AASIST (Trained on ASV-Spoof5)"],
59
- ["./fake.flac", "SSL-AASIST (Trained on ASV-Spoof5)"],
60
- ["./bonafide.flac", "AASIST"],
61
- ["./fake.flac", "AASIST"],
62
- ],
63
- cache_examples=True,
64
- allow_flagging="never",
65
- )
66
- #####################################################################################
67
- # For ASR interface
68
- pipe = pipeline(
69
- task="automatic-speech-recognition",
70
- model="openai/whisper-large-v3",
71
- chunk_length_s=30,
72
- device="cpu",
73
- )
74
-
75
- def transcribe(inputs):
76
- if inputs is None:
77
- raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
78
-
79
- op = pipe(inputs, batch_size=8, generate_kwargs={"task": "transcribe"}, return_timestamps=False, return_language=True)
80
- lang = op["chunks"][0]["language"]
81
- text = op["text"]
82
-
83
- return lang, text
84
-
85
- transcribe_proc = gr.Interface(
86
- fn = transcribe,
87
- inputs = [
88
- gr.Audio(type="filepath", label="Speech file (<30s)", max_length=30, sources=["microphone", "upload"], show_download_button=True)
89
- ],
90
- outputs=[
91
- gr.Text(label="Predicted Language", info="Language identification is performed automatically."),
92
- gr.Text(label="Predicted transcription", info="Best hypothesis."),
93
- ],
94
- title="Transcribe Anything.",
95
- description=(
96
- "Automatactic language identification and transcription service by Whisper Large V3. Upload a .wav or .flac file."
97
- ),
98
- allow_flagging="never"
99
- )
100
-
101
- #############################################################################################
102
- #For manipulation detection interface
103
-
104
- def detect_manipulation(inputs):
105
- global manipulate_model
106
- global manpulate_config
107
- out = infere(manipulate_model, inputs, manpulate_config)
108
- out = out.tolist()
109
- return str(out)
110
-
111
- manipulate_proc = gr.Interface(
112
- fn = detect_manipulation,
113
- inputs=[
114
- gr.Audio(type="filepath", label="Speech file (<30s)", max_length=30, sources=["microphone", "upload"], show_download_button=True)
115
- ],
116
- outputs=[
117
- gr.Text(label="Predicted manipulations", info="Manipulation detection is performed automatically."),
118
- ],
119
- title="Find the manipulated segments",
120
- description=(
121
- "Automatactic manipulation detection service. Upload a audio file."
122
- ),
123
- allow_flagging="never"
124
- )
125
-
126
- with demo:
127
- gr.TabbedInterface([file_proc, transcribe_proc, manipulate_proc], ["Analyze Audio File", "Transcribe Audio File", "Manipulation Detection"])
128
- demo.queue(max_size=10)
129
- demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models.py CHANGED
@@ -6,12 +6,13 @@ import torch.nn as nn
6
  from torch import Tensor
7
  from typing import Union
8
  import torch.nn.functional as F
 
9
 
10
  class SSLModel(nn.Module):
11
  def __init__(self, device):
12
  super(SSLModel, self).__init__()
13
 
14
- cp_path = 'xlsr2_300m.pt' # Change the pre-trained XLSR model path.
15
  model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
16
  self.model = model[0]
17
  self.device = device
 
6
  from torch import Tensor
7
  from typing import Union
8
  import torch.nn.functional as F
9
+ from huggingface_hub import hf_hub_download
10
 
11
  class SSLModel(nn.Module):
12
  def __init__(self, device):
13
  super(SSLModel, self).__init__()
14
 
15
+ cp_path = hf_hub_download("arnabdas8901/aasist-trained-asvspoof2024", filename='xlsr2_300m.pt') #'xlsr2_300m.pt' # Change the pre-trained XLSR model path.
16
  model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
17
  self.model = model[0]
18
  self.device = device