Siddhant committed
Commit 3200ea6
1 Parent(s): 862e4f5

Update app.py

Files changed (1)
app.py +100 -58
app.py CHANGED
@@ -1,63 +1,105 @@
- import gradio as gr
- from huggingface_hub import InferenceClient
- 
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
- 
- 
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
  ):
-     messages = [{"role": "system", "content": system_message}]
- 
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
- 
-     messages.append({"role": "user", "content": message})
- 
-     response = ""
- 
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
- 
-         response += token
-         yield response
- 
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
  )
 
 
  if __name__ == "__main__":
-     demo.launch()
+ from transformers import pipeline
+ import torch
+ 
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ 
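+ # wake-word detection: an audio classifier fine-tuned on Speech Commands (its labels include "marvin")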
+ classifier = pipeline(
+     "audio-classification", model="MIT/ast-finetuned-speech-commands-v2", device=device
+ )
+ 
+ from transformers.pipelines.audio_utils import ffmpeg_microphone_live
+ 
+ 
+ def launch_fn(
+     wake_word="marvin",
+     prob_threshold=0.5,
+     chunk_length_s=2.0,
+     stream_chunk_s=0.25,
+     debug=False,
  ):
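+     # stream the microphone and block until the wake word is predicted above prob_threshold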
+     if wake_word not in classifier.model.config.label2id.keys():
+         raise ValueError(
+             f"Wake word {wake_word} not in set of valid class labels, pick a wake word in the set {classifier.model.config.label2id.keys()}."
+         )
+ 
+     sampling_rate = classifier.feature_extractor.sampling_rate
+ 
+     mic = ffmpeg_microphone_live(
+         sampling_rate=sampling_rate,
+         chunk_length_s=chunk_length_s,
+         stream_chunk_s=stream_chunk_s,
+     )
+ 
+     print("Listening for wake word...")
+     for prediction in classifier(mic):
+         prediction = prediction[0]
+         if debug:
+             print(prediction)
+         if prediction["label"] == wake_word:
+             if prediction["score"] > prob_threshold:
+                 return True
+ 
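+ # speech-to-text: stream the microphone through Whisper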
+ transcriber = pipeline(
+     "automatic-speech-recognition", model="openai/whisper-base.en", device=device
  )
+ import sys
+ 
+ 
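+ # print interim transcriptions in place (the escape code clears the line) until the pipeline marks the utterance final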
+ def transcribe(chunk_length_s=5.0, stream_chunk_s=1.0):
+     sampling_rate = transcriber.feature_extractor.sampling_rate
+ 
+     mic = ffmpeg_microphone_live(
+         sampling_rate=sampling_rate,
+         chunk_length_s=chunk_length_s,
+         stream_chunk_s=stream_chunk_s,
+     )
+ 
+     print("Start speaking...")
+     for item in transcriber(mic, generate_kwargs={"max_new_tokens": 128}):
+         sys.stdout.write("\033[K")
+         print(item["text"], end="\r")
+         if not item["partial"][0]:
+             break
+ 
+     return item["text"]
+ 
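+ # language model: send the transcription to the Hugging Face Inference API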
+ from huggingface_hub import HfFolder
+ import requests
+ 
+ 
+ def query(text, model_id="tiiuae/falcon-7b-instruct"):
+     api_url = f"https://api-inference.huggingface.co/models/{model_id}"
+     headers = {"Authorization": f"Bearer {HfFolder().get_token()}"}
+     payload = {"inputs": text}
+ 
+     print(f"Querying...: {text}")
+     response = requests.post(api_url, headers=headers, json=payload)
+     return response.json()[0]["generated_text"][len(text) + 1 :]
+ 
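+ # text-to-speech: SpeechT5 with the HiFi-GAN vocoder and a fixed speaker x-vector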
+ from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
+ 
+ processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
+ 
+ model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
+ vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
+ 
+ from datasets import load_dataset
+ 
+ embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
+ speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
+ 
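+ # synthesise a 16 kHz waveform for the given text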
+ def synthesise(text):
+     inputs = processor(text=text, return_tensors="pt")
+     speech = model.generate_speech(
+         inputs["input_ids"].to(device), speaker_embeddings.to(device), vocoder=vocoder
+     )
+     return speech.cpu()
 
 
  if __name__ == "__main__":
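+     # assistant loop: wait for the wake word, transcribe the request, query the LLM, then speak the reply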
+     from IPython.display import Audio  # Audio was used below without an import; playback only renders in a notebook/IPython session
+ 
+     launch_fn()
+     transcription = transcribe()
+     response = query(transcription)
+     audio = synthesise(response)
+ 
+     Audio(audio, rate=16000, autoplay=True)