Raghavan1988 committed on
Commit b67fe1a
1 Parent(s): f86940b

Adding the predict method from facebook/seamless_m4t

Files changed (1)
  1. app.py +48 -0
app.py CHANGED
@@ -24,6 +24,54 @@ DEFAULT_TARGET_LANGUAGE = "English"
 AUDIO_SAMPLE_RATE = 16000.0
 MAX_INPUT_AUDIO_LENGTH = 60  # in seconds
+
+def predict(
+    task_name: str,
+    audio_source: str,
+    input_audio_mic: str | None,
+    input_audio_file: str | None,
+    input_text: str | None,
+    source_language: str | None,
+    target_language: str,
+) -> tuple[tuple[int, np.ndarray] | None, str]:
+    task_name = task_name.split()[0]
+    source_language_code = LANGUAGE_NAME_TO_CODE[source_language] if source_language else None
+    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
+
+    if task_name in ["S2ST", "S2TT", "ASR"]:
+        if audio_source == "microphone":
+            input_data = input_audio_mic
+        else:
+            input_data = input_audio_file
+
+        arr, org_sr = torchaudio.load(input_data)
+        new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE)
+        max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
+        if new_arr.shape[1] > max_length:
+            new_arr = new_arr[:, :max_length]
+            gr.Warning(f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used.")
+
+
+        input_data = processor(audios=new_arr, sampling_rate=AUDIO_SAMPLE_RATE, return_tensors="pt").to(device)
+    else:
+        input_data = processor(text=input_text, src_lang=source_language_code, return_tensors="pt").to(device)
+
+
+    if task_name in ["S2TT", "T2TT"]:
+        tokens_ids = model.generate(**input_data, generate_speech=False, tgt_lang=target_language_code, num_beams=5, do_sample=True)[0].cpu().squeeze().detach().tolist()
+    else:
+        output = model.generate(**input_data, return_intermediate_token_ids=True, tgt_lang=target_language_code, num_beams=5, do_sample=True, spkr_id=LANG_TO_SPKR_ID[target_language_code][0])
+
+        waveform = output.waveform.cpu().squeeze().detach().numpy()
+        tokens_ids = output.sequences.cpu().squeeze().detach().tolist()
+
+    text_out = processor.decode(tokens_ids, skip_special_tokens=True)
+
+    if task_name in ["S2ST", "T2ST"]:
+        return (int(AUDIO_SAMPLE_RATE), waveform), text_out
+    else:
+        return None, text_out
+
 def process_image_with_openai(image):
     image_data = convert_image_to_required_format(image)
     openai_api_key = config('OPENAI_API_KEY')  # Make sure to have this in your .env file
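
For reference, the hunk calls several globals that are defined elsewhere in app.py and not shown in this diff: processor, model, device, LANGUAGE_NAME_TO_CODE, and LANG_TO_SPKR_ID. A minimal sketch of what that setup might look like with the Transformers Seamless-M4T classes; the checkpoint name and the contents of the two lookup tables are assumptions, not part of this commit:

import torch
from transformers import AutoProcessor, SeamlessM4TModel

# Assumed setup (not in this hunk); the checkpoint name is a guess at what the Space loads.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-large")
model = SeamlessM4TModel.from_pretrained("facebook/hf-seamless-m4t-large").to(device)

# Assumed shapes of the lookup tables predict() indexes into:
# display name -> 3-letter language code, and language code -> list of valid speaker ids.
LANGUAGE_NAME_TO_CODE = {"English": "eng", "French": "fra"}
LANG_TO_SPKR_ID = {"eng": [0], "fra": [0]}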
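
And a hypothetical call mirroring what a Gradio handler could pass for text-to-text translation; note that predict() keeps only the token before the first space of task_name, so any label starting with "T2TT " selects the text path:

# Hypothetical invocation; the label format is assumed from task_name.split()[0] above.
audio_out, text_out = predict(
    task_name="T2TT (Text to Text translation)",
    audio_source="file",      # ignored for text-input tasks
    input_audio_mic=None,
    input_audio_file=None,
    input_text="Hello, how are you?",
    source_language="English",
    target_language="French",
)
print(text_out)  # translated text; audio_out is None because T2TT is not in ["S2ST", "T2ST"]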