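"""
Hey Gemma ☎️ — a live Gradio app that streams microphone audio, transcribes it
with Whisper large-v3, and answers with Gemma 7B-IT served by Groq.

Assumes a CUDA-capable GPU for the Whisper pipeline and a GROQ_API_KEY
environment variable for the Groq client.
"""
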
import os

import gradio as gr
import numpy as np
import torch
from groq import Groq
from transformers import pipeline
from transformers.utils import is_flash_attn_2_available

# Use Flash Attention 2 for faster inference when available; otherwise fall
# back to PyTorch's scaled dot-product attention (SDPA).
transcriber = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3",
    torch_dtype=torch.float16,
    device="cuda:0",
    model_kwargs=(
        {"attn_implementation": "flash_attention_2"}
        if is_flash_attn_2_available()
        else {"attn_implementation": "sdpa"}
    ),
)

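# The Groq client reads its API key from the GROQ_API_KEY environment variable.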
groq_client = Groq(api_key=os.getenv('GROQ_API_KEY'))

def transcribe(stream, new_chunk):
    """
    Append the new audio chunk to the running stream and transcribe
    the whole stream with Whisper.
    """
    sr, y = new_chunk

    # Convert multi-channel audio to mono by averaging the channels
    if y.ndim == 2:
        y = y.mean(axis=1)

    y = y.astype(np.float32)

    # Normalize to [-1, 1]; skip silent chunks to avoid dividing by zero
    peak = np.max(np.abs(y))
    if peak > 0:
        y /= peak

    # Accumulate audio so Whisper always sees the full utterance so far
    if stream is not None:
        stream = np.concatenate([stream, y])
    else:
        stream = y

    return stream, transcriber({"sampling_rate": sr, "raw": stream})["text"]

def autocomplete(text):
    """
    Send the transcribed text to Gemma (via Groq) and return its reply.
    """
    if not text:
        return ""  # Nothing to answer yet (e.g., the stream is still silent)

    response = groq_client.chat.completions.create(
        model="gemma-7b-it",
        messages=[
            {"role": "system", "content": "You are a friendly assistant named Gemma."},
            {"role": "user", "content": text},
        ],
    )

    return response.choices[0].message.content

def process_audio(input_audio, new_chunk):
    """
    Transcribe the incoming audio chunk and generate Gemma's reply,
    returning the updated stream state and the reply to the Gradio interface.
    """
    stream, transcription = transcribe(input_audio, new_chunk)
    text = autocomplete(transcription)

    # Log the running transcription and Gemma's reply for debugging
    print(transcription, text)
    return stream, text


demo = gr.Interface(
    fn=process_audio,
    inputs=["state", gr.Audio(sources=["microphone"], streaming=True)],
    outputs=["state", gr.Markdown()],
    title="Hey Gemma ☎️",
    description="Powered by [whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) and [gemma-7b-it](https://huggingface.co/google/gemma-7b-it) (via [Groq](https://groq.com/))",
    live=True,
    allow_flagging="never",
)

demo.launch()
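
# Note: when running outside a Hugging Face Space, demo.launch(share=True)
# would also expose the app through a temporary public Gradio URL.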