File size: 2,144 Bytes
2029b71
 
 
 
 
 
 
 
 
51bc0a8
 
64daabf
 
2029b71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e08f207
 
 
3d861be
2029b71
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Standard library
import csv
import os

# Third-party
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import scipy.signal
import tensorflow as tf
import tensorflow_hub as hub
from IPython.display import Audio
from scipy.io import wavfile

# Load the model.
model = hub.load('https://tfhub.dev/google/yamnet/1')

# Find the name of the class with the top score when mean-aggregated across frames.
def class_names_from_csv(class_map_csv_text):
  """Returns list of class names corresponding to score vector."""
  class_names = []
  with tf.io.gfile.GFile(class_map_csv_text) as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
      class_names.append(row['display_name'])

  return class_names

class_map_path = model.class_map_path().numpy()
class_names = class_names_from_csv(class_map_path)


def ensure_sample_rate(original_sample_rate, waveform,
                       desired_sample_rate=16000):
  """Resample waveform if required."""
  if original_sample_rate != desired_sample_rate:
    desired_length = int(round(float(len(waveform)) /
                               original_sample_rate * desired_sample_rate))
    waveform = scipy.signal.resample(waveform, desired_length)
  return desired_sample_rate, waveform
 
os.system("wget https://storage.googleapis.com/audioset/miaow_16k.wav")
 
def inference(audio): 
   # wav_file_name = 'speech_whistling2.wav'
  wav_file_name = audio
  sample_rate, wav_data = wavfile.read(wav_file_name, 'rb')
  sample_rate, wav_data = ensure_sample_rate(sample_rate, wav_data)
 
  waveform = wav_data / tf.int16.max
  
  # Run the model, check the output.
  scores, embeddings, spectrogram = model(waveform)
  
  scores_np = scores.numpy()
  spectrogram_np = spectrogram.numpy()
  infered_class = class_names[scores_np.mean(axis=0).argmax()]  
  
  return  f'The main sound is: {infered_class}'

examples=[['miaow_16k.wav']]
title="yamnet"
description="An audio event classifier trained on the AudioSet dataset to predict audio events from the AudioSet ontology."
gr.Interface(inference,gr.inputs.Audio(type="filepath"),"text",examples=examples,title=title,description=description).launch(enable_queue=True)