phamngoctukts committed
Commit 06ecff4 · verified · 1 Parent(s): 4ba7625

Update app.py

Files changed (1)
  1. app.py +156 -156
app.py CHANGED
@@ -1,157 +1,157 @@
- import speech_recognition as sr
- import ollama
- from gtts import gTTS
- import gradio as gr
- from io import BytesIO
- import numpy as np
- from dataclasses import dataclass, field
- import time
- import traceback
- from pydub import AudioSegment
- import librosa
- from utils.vad import get_speech_timestamps, collect_chunks, VadOptions
- from transformers import pipeline
- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
- import torch
- from huggingface_hub import login
- tk = "hf" + "_" + "qTOSlDtDtBgJbofv" + "MglsjjhQqbRAYRYnXy"
- login(tk)
-
- r = sr.Recognizer()
-
- tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
- model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
- text2text = pipeline("text-generation", model=model, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto", use_auth_token=True)
-
- @dataclass
- class AppState:
-     stream: np.ndarray | None = None
-     sampling_rate: int = 0
-     pause_detected: bool = False
-     started_talking: bool = False
-     stopped: bool = False
-     conversation: list = field(default_factory=list)
-
- def run_vad(ori_audio, sr):
-     _st = time.time()
-     try:
-         audio = ori_audio
-         audio = audio.astype(np.float32) / 32768.0
-         sampling_rate = 16000
-         if sr != sampling_rate:
-             audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)
-         vad_parameters = {}
-         vad_parameters = VadOptions(**vad_parameters)
-         speech_chunks = get_speech_timestamps(audio, vad_parameters)
-         audio = collect_chunks(audio, speech_chunks)
-         duration_after_vad = audio.shape[0] / sampling_rate
-         if sr != sampling_rate:
-             # resample to original sampling rate
-             vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
-         else:
-             vad_audio = audio
-         vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
-         vad_audio_bytes = vad_audio.tobytes()
-         return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
-     except Exception as e:
-         msg = f"[asr vad error] audio_len: {len(ori_audio)/(sr*2):.3f} s, trace: {traceback.format_exc()}"
-         print(msg)
-         return -1, ori_audio, round(time.time() - _st, 4)
-
- def determine_pause(audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
-     """Take in the stream, determine if a pause happened"""
-     temp_audio = audio
-     dur_vad, _, time_vad = run_vad(temp_audio, sampling_rate)
-     duration = len(audio) / sampling_rate
-     if dur_vad > 0.5 and not state.started_talking:
-         print("started talking")
-         state.started_talking = True
-         return False
-     print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")
-     return (duration - dur_vad) > 1
-
- def process_audio(audio:tuple, state:AppState):
-     if state.stream is None:
-         state.stream = audio[1]
-         state.sampling_rate = audio[0]
-     else:
-         state.stream = np.concatenate((state.stream, audio[1]))
-     pause_detected = determine_pause(state.stream, state.sampling_rate, state)
-     state.pause_detected = pause_detected
-     if state.pause_detected and state.started_talking:
-         return gr.Audio(recording=False), state
-     return None, state
-
- def response(state:AppState):
-     if not state.pause_detected and not state.started_talking:
-         return None, AppState()
-     audio_buffer = BytesIO()
-     segment = AudioSegment(
-         state.stream.tobytes(),
-         frame_rate=state.sampling_rate,
-         sample_width=state.stream.dtype.itemsize,
-         channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
-     )
-     segment.export(audio_buffer, format="wav")
-     textin = ""
-     with sr.AudioFile(audio_buffer) as source:
-         audio_data=r.record(source)
-         try:
-             textin=r.recognize_google(audio_data,language='vi')
-         except:
-             textin = ""
-     state.conversation.append({"role": "user", "content": "Bạn: " + textin})
-     if textin != "":
-         print("Đang nghĩ...")
-         textout=str(text2text(textin))
-         textout = textout.replace('*','')
-     state.conversation.append({"role": "user", "content": "Trợ lý: " + textout})
-     if textout != "":
-         print("Đang đọc...")
-         mp3 = gTTS(textout,tld='com.vn',lang='vi',slow=False)
-         mp3_fp = BytesIO()
-         mp3.write_to_fp(mp3_fp)
-         srr=mp3_fp.getvalue()
-         mp3_fp.close()
-     #yield srr, state
-     yield srr, AppState(conversation=state.conversation)
-
- def start_recording_user(state: AppState):
-     if not state.stopped:
-         return gr.Audio(recording=True)
-
- title = "vietnamese by tuphamkts"
- description = "A vietnamese text-to-speech demo."
-
- with gr.Blocks() as demo:
-     with gr.Row():
-         with gr.Column():
-             input_audio = gr.Audio(label="Nói cho tôi nghe nào", sources="microphone", type="numpy")
-         with gr.Column():
-             chatbot = gr.Chatbot(label="Nội dung trò chuyện", type="messages")
-             output_audio = gr.Audio(label="Trợ lý", autoplay=True)
-     state = gr.State(value=AppState())
-
-     stream = input_audio.stream(
-         process_audio,
-         [input_audio, state],
-         [input_audio, state],
-         stream_every=0.50,
-         time_limit=30,
-     )
-     respond = input_audio.stop_recording(
-         response,
-         [state],
-         [output_audio, state],
-     )
-     respond.then(lambda s: s.conversation, [state], [chatbot])
-
-     restart = output_audio.stop(
-         start_recording_user,
-         [state],
-         [input_audio],
-     )
-     cancel = gr.Button("Stop Conversation", variant="stop")
-     cancel.click(lambda: (AppState(stopped=True), gr.Audio(recording=False)), None,
-                  [state, input_audio], cancels=[respond, restart])
  demo.launch()
 
+ import speech_recognition as sr
+ import ollama
+ from gtts import gTTS
+ import gradio as gr
+ from io import BytesIO
+ import numpy as np
+ from dataclasses import dataclass, field
+ import time
+ import traceback
+ from pydub import AudioSegment
+ import librosa
+ from utils.vad import get_speech_timestamps, collect_chunks, VadOptions
+ from transformers import pipeline
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
+ import torch
+ from huggingface_hub import login
+ tk = "hf" + "_" + "qTOSlDtDtBgJbofv" + "MglsjjhQqbRAYRYnXy"
+ login(tk)
+
+ r = sr.Recognizer()
+
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
+ model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
+ text2text = pipeline("text-generation", model=model, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto", use_auth_token=True)
+
+ @dataclass
+ class AppState:
+     stream: np.ndarray | None = None
+     sampling_rate: int = 0
+     pause_detected: bool = False
+     started_talking: bool = False
+     stopped: bool = False
+     conversation: list = field(default_factory=list)
+
+ def run_vad(ori_audio, sr):
+     _st = time.time()
+     try:
+         audio = ori_audio
+         audio = audio.astype(np.float32) / 32768.0
+         sampling_rate = 16000
+         if sr != sampling_rate:
+             audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)
+         vad_parameters = {}
+         vad_parameters = VadOptions(**vad_parameters)
+         speech_chunks = get_speech_timestamps(audio, vad_parameters)
+         audio = collect_chunks(audio, speech_chunks)
+         duration_after_vad = audio.shape[0] / sampling_rate
+         if sr != sampling_rate:
+             # resample to original sampling rate
+             vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
+         else:
+             vad_audio = audio
+         vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
+         vad_audio_bytes = vad_audio.tobytes()
+         return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
+     except Exception as e:
+         msg = f"[asr vad error] audio_len: {len(ori_audio)/(sr*2):.3f} s, trace: {traceback.format_exc()}"
+         print(msg)
+         return -1, ori_audio, round(time.time() - _st, 4)
+
+ def determine_pause(audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
+     """Take in the stream, determine if a pause happened"""
+     temp_audio = audio
+     dur_vad, _, time_vad = run_vad(temp_audio, sampling_rate)
+     duration = len(audio) / sampling_rate
+     if dur_vad > 0.5 and not state.started_talking:
+         print("started talking")
+         state.started_talking = True
+         return False
+     print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")
+     return (duration - dur_vad) > 1
+
+ def process_audio(audio:tuple, state:AppState):
+     if state.stream is None:
+         state.stream = audio[1]
+         state.sampling_rate = audio[0]
+     else:
+         state.stream = np.concatenate((state.stream, audio[1]))
+     pause_detected = determine_pause(state.stream, state.sampling_rate, state)
+     state.pause_detected = pause_detected
+     if state.pause_detected and state.started_talking:
+         return gr.Audio(recording=False), state
+     return None, state
+
+ def response(state:AppState):
+     if not state.pause_detected and not state.started_talking:
+         return None, AppState()
+     audio_buffer = BytesIO()
+     segment = AudioSegment(
+         state.stream.tobytes(),
+         frame_rate=state.sampling_rate,
+         sample_width=state.stream.dtype.itemsize,
+         channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
+     )
+     segment.export(audio_buffer, format="wav")
+     textin = ""
+     with sr.AudioFile(audio_buffer) as source:
+         audio_data=r.record(source)
+         try:
+             textin=r.recognize_google(audio_data,language='vi')
+         except:
+             textin = ""
+     state.conversation.append({"role": "user", "content": "Bạn: " + textin})
+     if textin != "":
+         print("Đang nghĩ...")
+         textout=str(text2text(textin))
+         textout = textout.replace('*','')
+     state.conversation.append({"role": "user", "content": "Trợ lý: " + textout})
+     if textout != "":
+         print("Đang đọc...")
+         mp3 = gTTS(textout,tld='com.vn',lang='vi',slow=False)
+         mp3_fp = BytesIO()
+         mp3.write_to_fp(mp3_fp)
+         srr=mp3_fp.getvalue()
+         mp3_fp.close()
+     #yield srr, state
+     yield srr, AppState(conversation=state.conversation)
+
+ def start_recording_user(state: AppState):
+     if not state.stopped:
+         return gr.Audio(recording=True)
+
+ title = "vietnamese by tuphamkts"
+ description = "A vietnamese text-to-speech demo."
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column():
+             input_audio = gr.Audio(label="Nói cho tôi nghe nào", sources="microphone", type="numpy")
+         with gr.Column():
+             chatbot = gr.Chatbot(label="Nội dung trò chuyện", type="messages")
+             output_audio = gr.Audio(label="Trợ lý", autoplay=True)
+     state = gr.State(value=AppState())
+
+     stream = input_audio.stream(
+         process_audio,
+         [input_audio, state],
+         [input_audio, state],
+         stream_every=0.50,
+         time_limit=30,
+     )
+     respond = input_audio.stop_recording(
+         response,
+         [state],
+         [output_audio, state],
+     )
+     respond.then(lambda s: s.conversation, [state], [chatbot])
+
+     restart = output_audio.stop(
+         start_recording_user,
+         [state],
+         [input_audio],
+     )
+     cancel = gr.Button("Stop Conversation", variant="stop")
+     cancel.click(lambda: (AppState(stopped=True), gr.Audio(recording=False)), None,
+                  [state, input_audio], cancels=[respond, restart])
  demo.launch()
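
The only substantive change in this commit is the checkpoint swap from meta-llama/Meta-Llama-3-8B to meta-llama/Llama-3.2-1B in the tokenizer/model setup. For reference, a minimal sketch (not part of the commit) of wiring the new checkpoint into the text-generation pipeline: it assumes gated-repo access has been granted and a token is available via the HF_TOKEN environment variable (so nothing is hard-coded), passes the tokenizer explicitly (app.py loads one but never hands it to the pipeline), and reads the pipeline's list-of-dicts return value instead of wrapping it in str():

```python
# Sketch only; assumes access to the gated repo and HF_TOKEN set in the environment.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

model_id = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# Pass the tokenizer explicitly; app.py creates one but never gives it to the pipeline.
text2text = pipeline("text-generation", model=model, tokenizer=tokenizer)

out = text2text("Xin chào, bạn khỏe không?", max_new_tokens=100)
textout = out[0]["generated_text"]  # the pipeline returns [{"generated_text": ...}], not a plain string
print(textout)
```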
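
The bare `except:` around `recognize_google` in `response()` hides real failures (network errors, key or quota problems) behind an empty transcript. A small sketch of a narrower handler built on the exceptions the speech_recognition library itself raises; `transcribe_vi` is a hypothetical helper name, not something in app.py:

```python
import speech_recognition as sr

def transcribe_vi(recognizer: sr.Recognizer, audio_data: sr.AudioData) -> str:
    """Google Web Speech transcription with the library's own exceptions handled."""
    try:
        return recognizer.recognize_google(audio_data, language="vi")
    except sr.UnknownValueError:    # audio was unintelligible
        return ""
    except sr.RequestError as exc:  # API unreachable, key/quota problems, etc.
        print(f"[stt] request failed: {exc}")
        return ""
```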
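
gr.Chatbot(type="messages") renders dicts whose "role" is "user" or "assistant", but response() appends both turns with role "user", so every message lands on the same side of the chat. A sketch (reusing the names from response()) of how the two turns could be recorded so the roles, rather than the "Bạn:"/"Trợ lý:" prefixes, distinguish the speakers:

```python
# Transcribed question as the user turn, model reply as the assistant turn.
state.conversation.append({"role": "user", "content": textin})
state.conversation.append({"role": "assistant", "content": textout})
```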
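
response() yields the raw MP3 bytes from gTTS straight into gr.Audio; whether raw bytes are accepted depends on the installed Gradio version. If they are not, a sketch (hypothetical helper, assuming ffmpeg is available for pydub) of decoding the gTTS output into the (sample_rate, samples) tuple that a numpy-typed gr.Audio component accepts:

```python
from io import BytesIO
import numpy as np
from gtts import gTTS
from pydub import AudioSegment

def tts_to_numpy(text: str) -> tuple[int, np.ndarray]:
    """Render Vietnamese TTS and return (sample_rate, int16 samples) for gr.Audio."""
    mp3_fp = BytesIO()
    gTTS(text, tld="com.vn", lang="vi", slow=False).write_to_fp(mp3_fp)
    mp3_fp.seek(0)
    seg = AudioSegment.from_file(mp3_fp, format="mp3")  # decoding requires ffmpeg on PATH
    samples = np.array(seg.get_array_of_samples(), dtype=np.int16)
    if seg.channels > 1:
        samples = samples.reshape((-1, seg.channels))
    return seg.frame_rate, samples
```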