danielwm994 committed
Commit f384cdf · verified · 1 Parent(s): 6ce5e8f

Update app.py

Files changed (1)
  1. app.py +13 -174
app.py CHANGED
@@ -1,179 +1,18 @@
- import os
- import random
- from glob import glob
- from typing import List, Optional, Union, Dict
-
- import tqdm
  import torch
- import torchaudio
- import numpy as np
- import pandas as pd
- from torch import nn
- from torch.utils.data import DataLoader
- from torch.nn import functional as F
- from transformers import (
-     AutoFeatureExtractor,
-     AutoModelForAudioClassification,
-     Wav2Vec2Processor
- )
-
- class CustomDataset(torch.utils.data.Dataset):
-     def __init__(
-         self,
-         dataset: List,
-         basedir: Optional[str] = None,
-         sampling_rate: int = 16000,
-         max_audio_len: int = 5,
-     ):
-         self.dataset = dataset
-         self.basedir = basedir
-
-         self.sampling_rate = sampling_rate
-         self.max_audio_len = max_audio_len
-
-     def __len__(self):
-         """
-         Return the length of the dataset
-         """
-         return len(self.dataset)
-
-     def __getitem__(self, index):
-         if self.basedir is None:
-             filepath = self.dataset[index]
-         else:
-             filepath = os.path.join(self.basedir, self.dataset[index])
-
-         speech_array, sr = torchaudio.load(filepath)
-
-         if speech_array.shape[0] > 1:
-             speech_array = torch.mean(speech_array, dim=0, keepdim=True)
-
-         if sr != self.sampling_rate:
-             transform = torchaudio.transforms.Resample(sr, self.sampling_rate)
-             speech_array = transform(speech_array)
-             sr = self.sampling_rate
-
-         len_audio = speech_array.shape[1]
-
-         # Pad or truncate the audio to match the desired length
-         if len_audio < self.max_audio_len * self.sampling_rate:
-             # Pad the audio if it's shorter than the desired length
-             padding = torch.zeros(1, self.max_audio_len * self.sampling_rate - len_audio)
-             speech_array = torch.cat([speech_array, padding], dim=1)
-         else:
-             # Truncate the audio if it's longer than the desired length
-             speech_array = speech_array[:, :self.max_audio_len * self.sampling_rate]
-
-         speech_array = speech_array.squeeze().numpy()
-
-         return {"input_values": speech_array, "attention_mask": None}
-
-
- class CollateFunc:
-     def __init__(
-         self,
-         processor: Wav2Vec2Processor,
-         padding: Union[bool, str] = True,
-         pad_to_multiple_of: Optional[int] = None,
-         return_attention_mask: bool = True,
-         sampling_rate: int = 16000,
-         max_length: Optional[int] = None,
-     ):
-         self.sampling_rate = sampling_rate
-         self.processor = processor
-         self.padding = padding
-         self.pad_to_multiple_of = pad_to_multiple_of
-         self.return_attention_mask = return_attention_mask
-         self.max_length = max_length
-
-     def __call__(self, batch: List[Dict[str, np.ndarray]]):
-         # Extract input_values from the batch
-         input_values = [item["input_values"] for item in batch]
-
-         batch = self.processor(
-             input_values,
-             sampling_rate=self.sampling_rate,
-             return_tensors="pt",
-             padding=self.padding,
-             max_length=self.max_length,
-             pad_to_multiple_of=self.pad_to_multiple_of,
-             return_attention_mask=self.return_attention_mask
-         )
-
-         return {
-             "input_values": batch.input_values,
-             "attention_mask": batch.attention_mask if self.return_attention_mask else None
-         }
-
-
- def predict(test_dataloader, model, device: torch.device):
-     """
-     Predict the class of the audio
-     """
-     model.to(device)
-     model.eval()
-     preds = []
-
-     with torch.no_grad():
-         for batch in tqdm.tqdm(test_dataloader):
-             input_values, attention_mask = batch['input_values'].to(device), batch['attention_mask'].to(device)
-
-             logits = model(input_values, attention_mask=attention_mask).logits
-             scores = F.softmax(logits, dim=-1)
-
-             pred = torch.argmax(scores, dim=1).cpu().detach().numpy()
-
-             preds.extend(pred)
-
-     return preds
-
-
- def get_gender(model_name_or_path: str, audio_paths: List[str], label2id: Dict, id2label: Dict, device: torch.device):
-     num_labels = 2
-
-     feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)
-     model = AutoModelForAudioClassification.from_pretrained(
-         pretrained_model_name_or_path=model_name_or_path,
-         num_labels=num_labels,
-         label2id=label2id,
-         id2label=id2label,
-     )
-
-     test_dataset = CustomDataset(audio_paths, max_audio_len=5)  # for 5-second audio
-
-     data_collator = CollateFunc(
-         processor=feature_extractor,
-         padding=True,
-         sampling_rate=16000,
-     )
-
-     test_dataloader = DataLoader(
-         dataset=test_dataset,
-         batch_size=16,
-         collate_fn=data_collator,
-         shuffle=False,
-         num_workers=2
-     )
-
-     preds = predict(test_dataloader=test_dataloader, model=model, device=device)
-
-     return preds
-
- model_name_or_path = "alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech"
-
- audio_paths = []  # Must be a list with absolute paths of the audios that will be used in inference
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
- label2id = {
-     "female": 0,
-     "male": 1
- }
 
- id2label = {
-     0: "female",
-     1: "male"
- }
 
- num_labels = 2
 
- preds = get_gender(model_name_or_path, audio_paths, label2id, id2label, device)
+ import re
+ from gender_prediction import get_gender
+ import gradio as gr
  import torch
 
 
+ def app(voice):
+     model_name_or_path = "alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech"
+     audio_paths = [voice]
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     predicted_label = get_gender(model_name_or_path, audio_paths, device)
+     gender = re.search("female|male", predicted_label)
+     return gender.group()
 
 
+ interface = gr.Interface(fn=app, inputs=[gr.components.Audio(type="filepath", sources="upload", label="upload voice")],
+                          outputs=[gr.components.Textbox(label="your result")])
+ interface.launch(debug=True)
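
The new app.py imports get_gender from a local gender_prediction module that is not included in this diff; it presumably repackages the CustomDataset / CollateFunc / predict / get_gender code removed above. The following is only a hypothetical sketch of what that module could look like, assuming the three-argument signature used in the new app.py and that it returns the label string "female" or "male"; the actual gender_prediction.py may differ.

# gender_prediction.py -- hypothetical sketch, not part of this commit
from typing import List

import torch
import torchaudio
from torch.nn import functional as F
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

# Label maps carried over from the removed app.py code
label2id = {"female": 0, "male": 1}
id2label = {0: "female", 1: "male"}


def get_gender(model_name_or_path: str, audio_paths: List[str], device: torch.device) -> str:
    """Classify the speaker gender of the first audio file and return 'female' or 'male'."""
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)
    model = AutoModelForAudioClassification.from_pretrained(
        model_name_or_path,
        num_labels=2,
        label2id=label2id,
        id2label=id2label,
    ).to(device)
    model.eval()

    sampling_rate = 16000
    max_audio_len = 5  # seconds, as in the removed CustomDataset

    waves = []
    for path in audio_paths:
        speech_array, sr = torchaudio.load(path)
        if speech_array.shape[0] > 1:  # downmix multi-channel audio to mono
            speech_array = torch.mean(speech_array, dim=0, keepdim=True)
        if sr != sampling_rate:
            speech_array = torchaudio.transforms.Resample(sr, sampling_rate)(speech_array)
        # Pad or truncate to max_audio_len seconds, mirroring the removed dataset logic
        target_len = max_audio_len * sampling_rate
        if speech_array.shape[1] < target_len:
            padding = torch.zeros(1, target_len - speech_array.shape[1])
            speech_array = torch.cat([speech_array, padding], dim=1)
        else:
            speech_array = speech_array[:, :target_len]
        waves.append(speech_array.squeeze().numpy())

    inputs = feature_extractor(waves, sampling_rate=sampling_rate, return_tensors="pt", padding=True)

    with torch.no_grad():
        logits = model(inputs.input_values.to(device)).logits
        pred = torch.argmax(F.softmax(logits, dim=-1), dim=1)

    return id2label[int(pred[0])]

Under this assumption, the re.search("female|male", ...) call in app() simply extracts the predicted label from whatever string get_gender returns before showing it in the Gradio textbox.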