prashant-garg committed on
Commit 3b0b181 · 1 Parent(s): ab02fe1

gender detection app

Files changed (3)
  1. app.py +120 -2
  2. requirements.txt +54 -0
  3. runtime.txt +1 -0
app.py CHANGED
@@ -1,4 +1,122 @@
+ """
+ Streamlit application for real-time gender detection from audio input.
+ Uses wav2vec2 model to analyze voice and predict speaker gender.
+ """
+
  import streamlit as st
+ import pyaudio
+ import numpy as np
+ import torch
+ from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
+ import logging
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+ # Define audio stream parameters
+ FORMAT = pyaudio.paInt16  # 16-bit resolution
+ CHANNELS = 1              # Mono audio
+ RATE = 16000              # 16kHz sampling rate
+ CHUNK = 1024              # Number of frames per buffer
+
+ @st.cache_resource
+ def load_model():
+     """
+     Load the wav2vec2 model and feature extractor for gender recognition.
+
+     Returns:
+         tuple: A tuple containing the feature extractor and the model.
+     """
+     model_path = "alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech"
+     # model_path = "./local-model"
+     feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
+     model = AutoModelForAudioClassification.from_pretrained(model_path)
+     model.eval()
+     logging.info("Model loaded successfully.")
+     return feature_extractor, model
+
+ st.title("Gender Detection")
+
+ # Initialize session state
+ if 'listening' not in st.session_state:
+     st.session_state['listening'] = False
+ if 'prediction' not in st.session_state:
+     st.session_state['prediction'] = ""
+
+ # Function to stop listening
+ def stop_listening():
+     """Stop the audio stream and update session state to stop listening."""
+     if 'stream' in st.session_state:
+         logging.info("Stopping stream")
+         st.session_state['stream'].stop_stream()
+         st.session_state['stream'].close()
+     if 'audio' in st.session_state:
+         logging.info("Stopping audio")
+         st.session_state['audio'].terminate()
+     st.session_state['listening'] = False
+     st.session_state['prediction'] = "Stopped listening, click 'Start Listening' to start again."
+     st.rerun()
+
+ def start_listening():
+     """Start the audio stream and continuously process audio for gender detection."""
+     placeholder = st.empty()
+     try:
+         placeholder.write("Loading model...")
+         feature_extractor, model = load_model()
+         audio = pyaudio.PyAudio()
+         stream = audio.open(format=FORMAT,
+                             channels=CHANNELS,
+                             rate=RATE,
+                             input=True,
+                             frames_per_buffer=CHUNK)
+
+         st.session_state['stream'] = stream
+         st.session_state['audio'] = audio
+         st.session_state['listening'] = True
+         st.session_state['prediction'] = "Listening........................"
+         placeholder.write("Listening for audio...")
+
+         while st.session_state['listening']:
+             audio_data = np.array([], dtype=np.float32)
+
+             for _ in range(int(RATE / CHUNK * 1.5)):
+                 # Read audio chunk from the stream
+                 data = stream.read(CHUNK, exception_on_overflow=False)
+
+                 # Convert byte data to numpy array and normalize
+                 chunk_data = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
+                 audio_data = np.concatenate((audio_data, chunk_data))
+
+             # Check if there is significant sound
+             if np.max(np.abs(audio_data)) > 0.05:  # Threshold for detecting sound
+                 # Process the audio data
+                 inputs = feature_extractor(audio_data, sampling_rate=RATE, return_tensors="pt", padding=True)
+                 # Perform inference
+                 with torch.no_grad():
+                     logits = model(**inputs).logits
+                 predicted_ids = torch.argmax(logits, dim=-1)
+
+                 # Map predicted IDs to labels
+                 predicted_label = model.config.id2label[predicted_ids.item()]
+
+                 if predicted_label != st.session_state['prediction']:
+                     st.session_state['prediction'] = predicted_label
+                     # st.write(f"Detected Gender: {predicted_label}")
+                     placeholder.write(f"Detected Gender: {predicted_label}")
+             else:
+                 st.session_state['prediction'] = "---- No significant sound detected, skipping prediction. ----"
+                 placeholder.empty()
+         placeholder.empty()
+     except Exception as e:
+         logging.error(f"An error occurred: {e}")
+         st.error(f"An error occurred: {e}")
+         stop_listening()
  
- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
+ # Buttons to start and stop listening
+ col1, col2 = st.columns(2)
+ with col1:
+     if st.button("Start Listening"):
+         start_listening()
+ with col2:
+     if st.button("Stop Listening"):
+         stop_listening()
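For a quick offline check of the model this commit adds, the same inference path used in start_listening can be exercised on a pre-recorded clip instead of the microphone. A minimal sketch, assuming a hypothetical 16 kHz mono file sample.wav and the soundfile package (neither is part of this commit):

import soundfile as sf
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

# Load the same checkpoint referenced in app.py
model_path = "alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
model = AutoModelForAudioClassification.from_pretrained(model_path)
model.eval()

# "sample.wav" is a hypothetical 16 kHz mono recording; resample first if needed
audio, rate = sf.read("sample.wav", dtype="float32")
inputs = feature_extractor(audio, sampling_rate=rate, return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(dim=-1).item()])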
requirements.txt ADDED
@@ -0,0 +1,54 @@
+ altair==5.5.0
+ attrs==25.1.0
+ blinker==1.9.0
+ cachetools==5.5.1
+ certifi==2025.1.31
+ charset-normalizer==3.4.1
+ click==8.1.8
+ filelock==3.17.0
+ fsspec==2025.2.0
+ gitdb==4.0.12
+ GitPython==3.1.44
+ huggingface-hub==0.28.1
+ idna==3.10
+ Jinja2==3.1.5
+ jsonschema==4.23.0
+ jsonschema-specifications==2024.10.1
+ markdown-it-py==3.0.0
+ MarkupSafe==3.0.2
+ mdurl==0.1.2
+ mpmath==1.3.0
+ narwhals==1.26.0
+ networkx==3.4.2
+ numpy==2.2.3
+ packaging==24.2
+ pandas==2.2.3
+ pillow==11.1.0
+ protobuf==5.29.3
+ pyarrow==19.0.0
+ PyAudio==0.2.14
+ pydeck==0.9.1
+ Pygments==2.19.1
+ python-dateutil==2.9.0.post0
+ pytz==2025.1
+ PyYAML==6.0.2
+ referencing==0.36.2
+ regex==2024.11.6
+ requests==2.32.3
+ rich==13.9.4
+ rpds-py==0.22.3
+ safetensors==0.5.2
+ six==1.17.0
+ smmap==5.0.2
+ streamlit==1.42.0
+ sympy==1.13.1
+ tenacity==9.0.0
+ tokenizers==0.21.0
+ toml==0.10.2
+ torch==2.6.0
+ tornado==6.4.2
+ tqdm==4.67.1
+ transformers==4.48.3
+ typing_extensions==4.12.2
+ tzdata==2025.1
+ urllib3==2.3.0
runtime.txt ADDED
@@ -0,0 +1 @@
+ python-3.10
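To reproduce the environment locally, a typical setup sketch (assuming pip and the Python 3.10 runtime pinned in runtime.txt; PyAudio generally needs the PortAudio system library installed before it will build):

pip install -r requirements.txt
streamlit run app.py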