freddyaboulton (HF staff) committed
Commit 2ec6a8f · verified · 1 Parent(s): 6f17857

Upload folder using huggingface_hub

Files changed (4)
  1. README_gradio.md +15 -0
  2. app.py +151 -0
  3. index.html +417 -0
  4. requirements.txt +3 -0
README_gradio.md ADDED
@@ -0,0 +1,15 @@
+---
+title: Hello Computer (Gradio)
+emoji: 💻
+colorFrom: purple
+colorTo: red
+sdk: gradio
+sdk_version: 5.16.0
+app_file: app.py
+pinned: false
+license: mit
+short_description: Say computer (Gradio)
+tags: [webrtc, websocket, gradio, secret|TWILIO_ACCOUNT_SID, secret|TWILIO_AUTH_TOKEN, secret|SAMBANOVA_API_KEY]
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,151 @@
+import base64
+import json
+import os
+from pathlib import Path
+
+import gradio as gr
+import numpy as np
+import openai
+from dotenv import load_dotenv
+from fastapi import FastAPI
+from fastapi.responses import HTMLResponse, StreamingResponse
+from fastrtc import (
+    AdditionalOutputs,
+    ReplyOnStopWords,
+    Stream,
+    WebRTCError,
+    get_twilio_turn_credentials,
+    stt,
+)
+from gradio.utils import get_space
+from pydantic import BaseModel
+
+load_dotenv()
+
+curr_dir = Path(__file__).parent
+
+
+client = openai.OpenAI(
+    api_key=os.environ.get("SAMBANOVA_API_KEY"),
+    base_url="https://api.sambanova.ai/v1",
+)
+
+
+def response(
+    audio: tuple[int, np.ndarray],
+    gradio_chatbot: list[dict] | None = None,
+    conversation_state: list[dict] | None = None,
+):
+    gradio_chatbot = gradio_chatbot or []
+    conversation_state = conversation_state or []
+
+    text = stt(audio)
+    sample_rate, array = audio
+    gradio_chatbot.append(
+        {"role": "user", "content": gr.Audio((sample_rate, array.squeeze()))}
+    )
+    yield AdditionalOutputs(gradio_chatbot, conversation_state)
+
+    conversation_state.append({"role": "user", "content": text})
+
+    try:
+        request = client.chat.completions.create(
+            model="Meta-Llama-3.2-3B-Instruct",
+            messages=conversation_state,
+            temperature=0.1,
+            top_p=0.1,
+        )
+        response = {"role": "assistant", "content": request.choices[0].message.content}
+
+    except Exception:
+        import traceback
+
+        traceback.print_exc()
+        raise WebRTCError(traceback.format_exc())
+
+    conversation_state.append(response)
+    gradio_chatbot.append(response)
+
+    yield AdditionalOutputs(gradio_chatbot, conversation_state)
+
+
+chatbot = gr.Chatbot(type="messages", value=[])
+state = gr.State(value=[])
+stream = Stream(
+    ReplyOnStopWords(
+        response,  # type: ignore
+        stop_words=["computer"],
+        input_sample_rate=16000,
+    ),
+    mode="send",
+    modality="audio",
+    additional_inputs=[chatbot, state],
+    additional_outputs=[chatbot, state],
+    additional_outputs_handler=lambda *a: (a[2], a[3]),
+    concurrency_limit=20 if get_space() else None,
+)
+
+app = FastAPI()
+stream.mount(app)
+
+
+class Message(BaseModel):
+    role: str
+    content: str
+
+
+class InputData(BaseModel):
+    webrtc_id: str
+    chatbot: list[Message]
+    state: list[Message]
+
+
+@app.get("/")
+async def _():
+    rtc_config = get_twilio_turn_credentials() if get_space() else None
+    html_content = (curr_dir / "index.html").read_text()
+    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
+    return HTMLResponse(content=html_content)
+
+
+@app.post("/input_hook")
+async def _(data: InputData):
+    body = data.model_dump()
+    stream.set_input(data.webrtc_id, body["chatbot"], body["state"])
+
+
+def audio_to_base64(file_path):
+    audio_format = "wav"
+    with open(file_path, "rb") as audio_file:
+        encoded_audio = base64.b64encode(audio_file.read()).decode("utf-8")
+    return f"data:audio/{audio_format};base64,{encoded_audio}"
+
+
+@app.get("/outputs")
+async def _(webrtc_id: str):
+    async def output_stream():
+        async for output in stream.output_stream(webrtc_id):
+            chatbot = output.args[0]
+            state = output.args[1]
+            data = {
+                "message": state[-1],
+                "audio": audio_to_base64(chatbot[-1]["content"].value["path"])
+                if chatbot[-1]["role"] == "user"
+                else None,
+            }
+            yield f"event: output\ndata: {json.dumps(data)}\n\n"
+
+    return StreamingResponse(output_stream(), media_type="text/event-stream")
+
+
+if __name__ == "__main__":
+    import os
+
+    if (mode := os.getenv("MODE")) == "UI":
+        stream.ui.launch(server_port=7860)
+    elif mode == "PHONE":
+        raise ValueError("Phone mode not supported")
+    else:
+        import uvicorn
+
+        uvicorn.run(app, host="0.0.0.0", port=7860)
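
For context (not part of the commit): the `/outputs` route above streams chat updates as server-sent events. A minimal consumer sketch, assuming the app is running on localhost:7860 and that `webrtc_id` names an active stream; the `requests` dependency is an assumption and is not in requirements.txt:

```python
# Sketch of a consumer for the /outputs SSE endpoint defined above.
# Assumptions: server on localhost:7860, an active stream's webrtc_id,
# and the third-party `requests` package (not in requirements.txt).
import json

import requests


def follow_outputs(webrtc_id: str) -> None:
    resp = requests.get(
        "http://localhost:7860/outputs",
        params={"webrtc_id": webrtc_id},
        stream=True,
    )
    for line in resp.iter_lines(decode_unicode=True):
        # Frames arrive as "event: output" followed by "data: {...}".
        if line and line.startswith("data: "):
            payload = json.loads(line[len("data: "):])
            print(payload["message"])  # {"role": ..., "content": ...}
```

The browser client in index.html below does the same thing with an `EventSource`.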
index.html ADDED
@@ -0,0 +1,417 @@
+<!DOCTYPE html>
+<html lang="en">
+
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Hello Computer 💻</title>
+    <style>
+        body {
+            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
+            background-color: #f8f9fa;
+            color: #1a1a1a;
+            margin: 0;
+            padding: 20px;
+            height: 100vh;
+            box-sizing: border-box;
+        }
+
+        .container {
+            max-width: 800px;
+            margin: 0 auto;
+            height: calc(100% - 100px);
+        }
+
+        .logo {
+            text-align: center;
+            margin-bottom: 40px;
+        }
+
+        .chat-container {
+            background: white;
+            border-radius: 8px;
+            box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
+            padding: 20px;
+            height: 90%;
+            box-sizing: border-box;
+            display: flex;
+            flex-direction: column;
+        }
+
+        .chat-messages {
+            flex-grow: 1;
+            overflow-y: auto;
+            margin-bottom: 20px;
+            padding: 10px;
+        }
+
+        .message {
+            margin-bottom: 20px;
+            padding: 12px;
+            border-radius: 8px;
+            font-size: 14px;
+            line-height: 1.5;
+        }
+
+        .message.user {
+            background-color: #e9ecef;
+            margin-left: 20%;
+        }
+
+        .message.assistant {
+            background-color: #f1f3f5;
+            margin-right: 20%;
+        }
+
+        .controls {
+            text-align: center;
+            margin-top: 20px;
+        }
+
+        button {
+            background-color: #0066cc;
+            color: white;
+            border: none;
+            padding: 12px 24px;
+            font-family: inherit;
+            font-size: 14px;
+            cursor: pointer;
+            transition: all 0.3s;
+            border-radius: 4px;
+            font-weight: 500;
+        }
+
+        button:hover {
+            background-color: #0052a3;
+        }
+
+        #audio-output {
+            display: none;
+        }
+
+        .icon-with-spinner {
+            display: flex;
+            align-items: center;
+            justify-content: center;
+            gap: 12px;
+            min-width: 180px;
+        }
+
+        .spinner {
+            width: 20px;
+            height: 20px;
+            border: 2px solid #ffffff;
+            border-top-color: transparent;
+            border-radius: 50%;
+            animation: spin 1s linear infinite;
+            flex-shrink: 0;
+        }
+
+        @keyframes spin {
+            to {
+                transform: rotate(360deg);
+            }
+        }
+
+        .pulse-container {
+            display: flex;
+            align-items: center;
+            justify-content: center;
+            gap: 12px;
+            min-width: 180px;
+        }
+
+        .pulse-circle {
+            width: 20px;
+            height: 20px;
+            border-radius: 50%;
+            background-color: #ffffff;
+            opacity: 0.2;
+            flex-shrink: 0;
+            transform: translateX(-0%) scale(var(--audio-level, 1));
+            transition: transform 0.1s ease;
+        }
+
+        /* Add styles for typing indicator */
+        .typing-indicator {
+            padding: 8px;
+            background-color: #f1f3f5;
+            border-radius: 8px;
+            margin-bottom: 10px;
+            display: none;
+        }
+
+        .dots {
+            display: inline-flex;
+            gap: 4px;
+        }
+
+        .dot {
+            width: 8px;
+            height: 8px;
+            background-color: #0066cc;
+            border-radius: 50%;
+            animation: pulse 1.5s infinite;
+            opacity: 0.5;
+        }
+
+        .dot:nth-child(2) {
+            animation-delay: 0.5s;
+        }
+
+        .dot:nth-child(3) {
+            animation-delay: 1s;
+        }
+
+        @keyframes pulse {
+
+            0%,
+            100% {
+                opacity: 0.5;
+                transform: scale(1);
+            }
+
+            50% {
+                opacity: 1;
+                transform: scale(1.2);
+            }
+        }
+    </style>
+</head>
+
+<body>
+    <div class="container">
+        <div class="logo">
+            <h1>Hello Computer 💻</h1>
+            <h2 style="font-size: 1.2em; color: #666; margin-top: 10px;">Say 'Computer' before asking your question</h2>
+        </div>
+        <div class="chat-container">
+            <div class="chat-messages" id="chat-messages"></div>
+            <div class="typing-indicator" id="typing-indicator">
+                <div class="dots">
+                    <div class="dot"></div>
+                    <div class="dot"></div>
+                    <div class="dot"></div>
+                </div>
+            </div>
+        </div>
+        <div class="controls">
+            <button id="start-button">Start Conversation</button>
+        </div>
+    </div>
+    <audio id="audio-output"></audio>
+
+    <script>
+        let peerConnection;
+        let webrtc_id;
+        const startButton = document.getElementById('start-button');
+        const chatMessages = document.getElementById('chat-messages');
+
+        let audioLevel = 0;
+        let animationFrame;
+        let audioContext, analyser, audioSource;
+        let messages = [];
+        let eventSource;
+
+        function updateButtonState() {
+            const button = document.getElementById('start-button');
+            if (peerConnection && (peerConnection.connectionState === 'connecting' || peerConnection.connectionState === 'new')) {
+                button.innerHTML = `
+                    <div class="icon-with-spinner">
+                        <div class="spinner"></div>
+                        <span>Connecting...</span>
+                    </div>
+                `;
+            } else if (peerConnection && peerConnection.connectionState === 'connected') {
+                button.innerHTML = `
+                    <div class="pulse-container">
+                        <div class="pulse-circle"></div>
+                        <span>Stop Conversation</span>
+                    </div>
+                `;
+            } else {
+                button.innerHTML = 'Start Conversation';
+            }
+        }
+
+        function setupAudioVisualization(stream) {
+            audioContext = new (window.AudioContext || window.webkitAudioContext)();
+            analyser = audioContext.createAnalyser();
+            audioSource = audioContext.createMediaStreamSource(stream);
+            audioSource.connect(analyser);
+            analyser.fftSize = 64;
+            const dataArray = new Uint8Array(analyser.frequencyBinCount);
+
+            function updateAudioLevel() {
+                analyser.getByteFrequencyData(dataArray);
+                const average = Array.from(dataArray).reduce((a, b) => a + b, 0) / dataArray.length;
+                audioLevel = average / 255;
+
+                const pulseCircle = document.querySelector('.pulse-circle');
+                if (pulseCircle) {
+                    pulseCircle.style.setProperty('--audio-level', 1 + audioLevel);
+                }
+
+                animationFrame = requestAnimationFrame(updateAudioLevel);
+            }
+            updateAudioLevel();
+        }
+
+        function handleMessage(event) {
+            const eventJson = JSON.parse(event.data);
+            const typingIndicator = document.getElementById('typing-indicator');
+
+            if (eventJson.type === "send_input") {
+                fetch('/input_hook', {
+                    method: 'POST',
+                    headers: {
+                        'Content-Type': 'application/json',
+                    },
+                    body: JSON.stringify({
+                        webrtc_id: webrtc_id,
+                        chatbot: messages,
+                        state: messages
+                    })
+                });
+            } else if (eventJson.type === "log") {
+                if (eventJson.data === "pause_detected") {
+                    typingIndicator.style.display = 'block';
+                    chatMessages.scrollTop = chatMessages.scrollHeight;
+                } else if (eventJson.data === "response_starting") {
+                    typingIndicator.style.display = 'none';
+                }
+            }
+        }
+
+        async function setupWebRTC() {
+            const config = __RTC_CONFIGURATION__;
+            peerConnection = new RTCPeerConnection(config);
+
+            try {
+                const stream = await navigator.mediaDevices.getUserMedia({
+                    audio: true
+                });
+
+                setupAudioVisualization(stream);
+
+                stream.getTracks().forEach(track => {
+                    peerConnection.addTrack(track, stream);
+                });
+
+                const dataChannel = peerConnection.createDataChannel('text');
+                dataChannel.onmessage = handleMessage;
+
+                const offer = await peerConnection.createOffer();
+                await peerConnection.setLocalDescription(offer);
+
+                await new Promise((resolve) => {
+                    if (peerConnection.iceGatheringState === "complete") {
+                        resolve();
+                    } else {
+                        const checkState = () => {
+                            if (peerConnection.iceGatheringState === "complete") {
+                                peerConnection.removeEventListener("icegatheringstatechange", checkState);
+                                resolve();
+                            }
+                        };
+                        peerConnection.addEventListener("icegatheringstatechange", checkState);
+                    }
+                });
+
+                peerConnection.addEventListener('connectionstatechange', () => {
+                    console.log('connectionstatechange', peerConnection.connectionState);
+                    updateButtonState();
+                });
+
+                webrtc_id = Math.random().toString(36).substring(7);
+
+                const response = await fetch('/webrtc/offer', {
+                    method: 'POST',
+                    headers: { 'Content-Type': 'application/json' },
+                    body: JSON.stringify({
+                        sdp: peerConnection.localDescription.sdp,
+                        type: peerConnection.localDescription.type,
+                        webrtc_id: webrtc_id
+                    })
+                });
+
+                const serverResponse = await response.json();
+                await peerConnection.setRemoteDescription(serverResponse);
+
+                eventSource = new EventSource('/outputs?webrtc_id=' + webrtc_id);
+                eventSource.addEventListener("output", (event) => {
+                    const eventJson = JSON.parse(event.data);
+                    console.log(eventJson);
+                    messages.push(eventJson.message);
+                    addMessage(eventJson.message.role, eventJson.audio ?? eventJson.message.content);
+                });
+            } catch (err) {
+                console.error('Error setting up WebRTC:', err);
+            }
+        }
+
+        function addMessage(role, content) {
+            const messageDiv = document.createElement('div');
+            messageDiv.classList.add('message', role);
+
+            if (role === 'user') {
+                // Create audio element for user messages
+                const audio = document.createElement('audio');
+                audio.controls = true;
+                audio.src = content;
+                messageDiv.appendChild(audio);
+            } else {
+                // Text content for assistant messages
+                messageDiv.textContent = content;
+            }
+
+            chatMessages.appendChild(messageDiv);
+            chatMessages.scrollTop = chatMessages.scrollHeight;
+        }
+
+        function stop() {
+            if (eventSource) {
+                eventSource.close();
+                eventSource = null;
+            }
+
+            if (animationFrame) {
+                cancelAnimationFrame(animationFrame);
+            }
+            if (audioContext) {
+                audioContext.close();
+                audioContext = null;
+                analyser = null;
+                audioSource = null;
+            }
+            if (peerConnection) {
+                if (peerConnection.getTransceivers) {
+                    peerConnection.getTransceivers().forEach(transceiver => {
+                        if (transceiver.stop) {
+                            transceiver.stop();
+                        }
+                    });
+                }
+
+                if (peerConnection.getSenders) {
+                    peerConnection.getSenders().forEach(sender => {
+                        if (sender.track && sender.track.stop) sender.track.stop();
+                    });
+                }
+                peerConnection.close();
+            }
+            updateButtonState();
+            audioLevel = 0;
+        }
+
+        startButton.addEventListener('click', () => {
+            if (!peerConnection || peerConnection.connectionState !== 'connected') {
+                setupWebRTC();
+            } else {
+                stop();
+            }
+        });
+    </script>
+</body>

+</html>
requirements.txt ADDED
@@ -0,0 +1,3 @@
+fastrtc[stopword]
+python-dotenv
+openai
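
To try the app locally (a sketch of the obvious workflow; the commit itself does not document it), install these pinned dependencies and export the secrets that app.py reads. A pre-flight check along these lines can catch a missing key before launch:

```python
# Pre-flight sketch: verify the environment app.py expects.
# SAMBANOVA_API_KEY feeds the OpenAI-compatible client; the Twilio
# secrets are only consulted for TURN credentials when on a Space.
import os

if not os.environ.get("SAMBANOVA_API_KEY"):
    raise SystemExit("Set SAMBANOVA_API_KEY before starting the app.")
print("Environment OK: run `python app.py`, or `MODE=UI python app.py` for the Gradio UI.")
```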