Spaces:
Running
Running
import json
import os
import time

import gradio as gr
import numpy as np
import openai
import torch
from transformers import pipeline
from transformers import AutoProcessor, AutoModelForCTC
# Single checkpoint id shared by the processor and the CTC model so the two
# can never drift apart.
ASR_CHECKPOINT = "facebook/wav2vec2-large-robust-ft-libri-960h"

# Both objects are created once at import time and reused by every request.
processor = AutoProcessor.from_pretrained(ASR_CHECKPOINT)
model = AutoModelForCTC.from_pretrained(ASR_CHECKPOINT)

# The OpenAI key is read from the environment; it is None when the
# OPENAI_KEY variable is not set.
openai.api_key = os.getenv("OPENAI_KEY")
def _to_mono_waveform(samples, source_rate, target_rate):
    """Return *samples* as a mono float32 waveform sampled at *target_rate*."""
    waveform = np.asarray(samples)
    if np.issubdtype(waveform.dtype, np.integer):
        # Integer PCM: scale into [-1, 1] before handing it to the model.
        scale = float(np.iinfo(waveform.dtype).max)
        waveform = waveform.astype(np.float32) / scale
    else:
        waveform = waveform.astype(np.float32)

    if waveform.ndim > 1:
        # Assumes the (samples, channels) layout Gradio documents for numpy
        # audio -- TODO confirm for the installed Gradio version.
        waveform = waveform.mean(axis=1)

    if source_rate != target_rate and waveform.size:
        # Dependency-free linear resampling; there is no anti-aliasing
        # filter, which is a trade-off for not adding a new package.
        target_len = max(1, int(round(waveform.size * target_rate / source_rate)))
        source_times = np.arange(waveform.size) / source_rate
        target_times = np.arange(target_len) / target_rate
        waveform = np.interp(target_times, source_times, waveform).astype(np.float32)

    return waveform


def _parse_verdict(raw_reply):
    """Extract ``(classification, reasons)`` from the chat model's reply.

    The model is only *asked* to answer in JSON, so a reply that is not a
    JSON object is reported as ``"unknown"`` with the raw text as the reason
    instead of raising.
    """
    try:
        verdict = json.loads(raw_reply)
    except (json.JSONDecodeError, TypeError):
        return "unknown", raw_reply
    if not isinstance(verdict, dict):
        return "unknown", raw_reply
    return verdict.get("classification", "unknown"), verdict.get("reasons", "")


def classify_audio(audio):
    """Transcribe an audio clip and ask the chat model whether it is a scam.

    Args:
        audio: ``(sample_rate, samples)`` tuple as produced by a Gradio audio
            input with ``type="numpy"``, or ``None`` when no audio is present.

    Returns:
        A ``(transcription, classification, reasons)`` tuple. All three are
        empty strings when there is no audio to process.
    """
    # The callback may be invoked without audio (for instance in live mode);
    # there is nothing to transcribe or classify in that case.
    if audio is None:
        return "", "", ""
    sample_rate, samples = audio
    if np.size(samples) == 0:
        return "", "", ""

    # Feed the model audio at the rate its feature extractor was set up for.
    target_rate = processor.feature_extractor.sampling_rate
    waveform = _to_mono_waveform(samples, sample_rate, target_rate)

    input_values = processor(
        waveform,
        sampling_rate=target_rate,
        return_tensors="pt",
        padding="longest",
    ).input_values

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy CTC decoding; batch_decode returns one string per batch item
    # and the batch here always holds exactly one clip.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]

    messages = [
        {"role": "system", "content": "Is this chat a scam, spam or is safe? Only answer in JSON format with 'classification': '' as string and 'reasons': '' as the most plausible reasons why. The reason should be explaning to the potential victim why the conversation is probably a scam"},
        {"role": "user", "content": transcription},
    ]

    # Call the OpenAI API to generate a response
    response = openai.ChatCompletion.create(
        model="gpt-4",  # Replace with the actual GPT-4 model ID
        messages=messages,
    )

    decision, reasons = _parse_verdict(response.choices[0].message['content'])
    return transcription, decision, reasons
# Input: an uploaded audio file handed to the callback in numpy form.
audio_input = gr.inputs.Audio(source="upload", type="numpy")

# Outputs: one textbox per value returned by classify_audio, in order.
result_boxes = [
    gr.outputs.Textbox(label=caption)
    for caption in ("Transcription", "Classification", "Reason")
]

# Build the UI with live mode enabled, then start serving it.
demo = gr.Interface(
    fn=classify_audio,
    inputs=audio_input,
    outputs=result_boxes,
    live=True,
)
demo.launch()