badrex commited on
Commit
b04a244
·
1 Parent(s): 07a50af

fix AttributeError with .shape

Browse files
Files changed (1) hide show
  1. app.py +13 -6
app.py CHANGED
@@ -15,13 +15,20 @@ dialect_mapping = {
15
  "Maghrebi": "Maghrebi Arabic"
16
  }
17
 
18
- def predict_dialect(audio, sr):
 
 
 
 
 
 
 
19
  # Process the audio input
20
- if len(audio.shape) > 1:
21
- audio = audio.mean(axis=1) # Convert stereo to mono
22
 
23
  # Classify the dialect
24
- predictions = classifier({"sampling_rate": sr, "raw": audio})
25
 
26
  # Format results for display
27
  results = {}
@@ -34,7 +41,7 @@ def predict_dialect(audio, sr):
34
  # Create the Gradio interface
35
  demo = gr.Interface(
36
  fn=predict_dialect,
37
- inputs=gr.Audio(type="numpy", label="Upload or Record Audio"),
38
  outputs=gr.Label(num_top_classes=5, label="Predicted Dialect"),
39
  title="Arabic Dialect Identifier",
40
  description="""This demo identifies Arabic dialects from speech audio.
@@ -49,4 +56,4 @@ demo = gr.Interface(
49
  )
50
 
51
  # Launch the app
52
- demo.launch()
 
15
  "Maghrebi": "Maghrebi Arabic"
16
  }
17
 
18
+ def predict_dialect(audio):
19
+ # The audio input from Gradio is a tuple of (sample_rate, audio_array)
20
+ if isinstance(audio, tuple) and len(audio) == 2:
21
+ sr, audio_array = audio
22
+ else:
23
+ # Handle error case
24
+ return {"Error": 1.0}
25
+
26
  # Process the audio input
27
+ if len(audio_array.shape) > 1:
28
+ audio_array = audio_array.mean(axis=1) # Convert stereo to mono
29
 
30
  # Classify the dialect
31
+ predictions = classifier({"sampling_rate": sr, "raw": audio_array})
32
 
33
  # Format results for display
34
  results = {}
 
41
  # Create the Gradio interface
42
  demo = gr.Interface(
43
  fn=predict_dialect,
44
+ inputs=gr.Audio(), # Simplified audio input
45
  outputs=gr.Label(num_top_classes=5, label="Predicted Dialect"),
46
  title="Arabic Dialect Identifier",
47
  description="""This demo identifies Arabic dialects from speech audio.
 
56
  )
57
 
58
  # Launch the app
59
+ demo.launch()