import logging
import os

import gradio as gr
import IPython.display as ipd
import numpy as np
import scipy
import scipy.signal
import torch
from datasets import load_dataset
from duckduckgo_search import DDGS
from newspaper import Article
from transformers import (
    AdamW,
    AutoTokenizer,
    MT5ForConditionalGeneration,
    MT5Tokenizer,
    VitsModel,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    pipeline,
)
|
|
class Webapp:
    """Voice-driven Arabic news summarizer exposed through a Gradio UI.

    The request pipeline implemented by the methods below is:
    recorded topic -> Whisper transcription -> DuckDuckGo news search ->
    article download -> mT5 summarization -> VITS speech synthesis.
    """

    # Rate the recording is resampled to before it reaches the Whisper processor.
    STT_SAMPLE_RATE = 16000
    # Gain applied to the synthesized waveform before conversion to int16.
    PCM_SCALE = 15000
    # Article used when the news search raises.
    FALLBACK_URL = 'https://www.bbc.com/arabic/media-65576589'

    _logger = logging.getLogger(__name__)

    def __init__(self):
        """Load the TTS, summarization and STT models onto the chosen device."""
        # Index 0 selects the first CUDA device when one is available.
        self.DEVICE = 0 if torch.cuda.is_available() else "cpu"
        self.REF_MODEL = 'google/mt5-small'
        self.MODEL_NAME = 'Ahmedasd/arabic-summarization-hhh-100-batches'
        self.model_id = "openai/whisper-base"
        self.tts_model_id = "SeyedAli/Arabic-Speech-synthesis"

        self.tts_model = VitsModel.from_pretrained(self.tts_model_id).to(self.DEVICE)
        self.tts_tokenizer = AutoTokenizer.from_pretrained(self.tts_model_id)

        # The tokenizer comes from the reference checkpoint, the weights from
        # the fine-tuned one.
        self.summ_tokenizer = MT5Tokenizer.from_pretrained(self.REF_MODEL)
        self.summ_model = MT5ForConditionalGeneration.from_pretrained(self.MODEL_NAME).to(self.DEVICE)

        # Not read anywhere else in this class; kept as a public attribute.
        self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        self.stt_model = WhisperForConditionalGeneration.from_pretrained(self.model_id)
        self.stt_model.to(self.DEVICE)

        self.processor = WhisperProcessor.from_pretrained(self.model_id)
        self.forced_decoder_ids = self.processor.get_decoder_prompt_ids(language="arabic", task="transcribe")

    @staticmethod
    def _prepare_audio(sample_rate, samples, target_rate):
        """Return mono float32 audio resampled from ``sample_rate`` to ``target_rate``.

        Integer PCM is scaled by the maximum of its dtype so values land in
        roughly [-1, 1]. Multi-channel input is averaged over axis 1
        (assumes a ``(samples, channels)`` layout -- TODO confirm against the
        Gradio version in use).

        Raises:
            ValueError: if ``samples`` is empty.
        """
        audio = np.asarray(samples)
        if audio.size == 0:
            raise ValueError('empty audio recording')
        if np.issubdtype(audio.dtype, np.integer):
            full_scale = np.iinfo(audio.dtype).max
            audio = audio.astype(np.float32) / full_scale
        else:
            audio = audio.astype(np.float32)
        if audio.ndim > 1:
            audio = audio.mean(axis=1)
        if sample_rate != target_rate:
            new_length = int(len(audio) * target_rate / sample_rate)
            audio = scipy.signal.resample(audio, new_length)
        return np.asarray(audio, dtype=np.float32)

    @staticmethod
    def _merge_voices(voices, scale, gap_seconds=0.5):
        """Join synthesized clips into one ``(rate, int16 samples)`` tuple.

        Args:
            voices: dicts with a ``'wav'`` waveform (tensor or array) and a
                ``'rate'`` sample rate, as returned by ``text_to_speech``.
            scale: gain applied before the int16 conversion.
            gap_seconds: silence inserted between consecutive clips.

        Returns:
            ``None`` when ``voices`` is empty, otherwise the tuple above. The
            rate of the first clip is used for the whole result.
        """
        if not voices:
            return None
        rate = voices[0]['rate']
        limits = np.iinfo(np.int16)
        silence = np.zeros(int(rate * gap_seconds), dtype=np.int16)
        chunks = []
        for voice in voices:
            wav = voice['wav']
            if hasattr(wav, 'detach'):
                # Tensor input: move it to host memory first.
                wav = wav.detach().cpu().numpy()
            scaled = np.asarray(wav, dtype=np.float32).reshape(-1) * scale
            # Clip so loud samples saturate instead of wrapping around.
            pcm = np.clip(scaled, limits.min, limits.max).astype(np.int16)
            if chunks:
                chunks.append(silence)
            chunks.append(pcm)
        return rate, np.concatenate(chunks)

    def speech_to_text(self, input):
        """Transcribe a recorded clip with the Whisper model.

        Args:
            input: ``(sample_rate, samples)`` pair; the sample rate is taken
                from the pair instead of being assumed.

        Returns:
            The list of strings produced by ``processor.batch_decode``.
        """
        sample_rate, samples = input
        audio = self._prepare_audio(sample_rate, samples, self.STT_SAMPLE_RATE)
        self._logger.debug('resampled audio: %d samples', len(audio))
        input_features = self.processor(
            audio,
            sampling_rate=self.STT_SAMPLE_RATE,
            return_tensors="pt",
        ).input_features.to(self.DEVICE)
        predicted_ids = self.stt_model.generate(input_features, forced_decoder_ids=self.forced_decoder_ids)
        transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
        return transcription

    def get_articles(self, query, num):
        """Search news for ``query`` and return the text of up to ``num`` articles.

        When the search raises, the error is logged and ``FALLBACK_URL`` is
        used instead. Articles that fail to download or parse, or that have
        no text, are skipped, so the returned list may be shorter than
        ``num`` (possibly empty).
        """
        with DDGS(timeout=20) as ddgs:
            try:
                # int(): the value comes from a UI slider and may not be an int.
                results = ddgs.news(query, max_results=int(num))
                urls = [r['url'] for r in results]
                self._logger.info('successful connection!')
            except Exception:
                # Best-effort boundary: keep the app usable, but record why.
                self._logger.exception('news search failed, using fallback article')
                urls = [self.FALLBACK_URL]

        articles = []
        for url in urls:
            try:
                article = Article(url)
                article.download()
                article.parse()
            except Exception:
                self._logger.exception('could not fetch article %s', url)
                continue
            # Collapse newlines to single spaces so words on adjacent lines
            # are not glued together.
            text = ' '.join(article.text.split())
            if text:
                articles.append(text)
        return articles

    def summarize(self, text, model=None):
        """Summarize ``text`` and return the summary as one string.

        Args:
            text: article text; truncated to 512 tokens by the tokenizer.
            model: generation model to use; ``None`` selects ``self.summ_model``.
        """
        if model is None:
            model = self.summ_model
        text_encoding = self.summ_tokenizer(
            text,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt',
        )
        generated_ids = model.generate(
            input_ids=text_encoding['input_ids'].to(self.DEVICE),
            attention_mask=text_encoding['attention_mask'].to(self.DEVICE),
            max_length=128,
            repetition_penalty=2.5,
        )
        preds = [
            self.summ_tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for gen_id in generated_ids
        ]
        return "".join(preds)

    def summarize_articles(self, articles: list, model):
        """Return one summary per entry of ``articles`` (a list of strings)."""
        return [self.summarize(article, model) for article in articles]

    def text_to_speech(self, text):
        """Synthesize ``text`` with the VITS model.

        Returns:
            ``{'wav': waveform tensor, 'rate': model sampling rate}``.
        """
        self._logger.debug('text_to_speech text: %s', text)
        inputs = self.tts_tokenizer(text, return_tensors="pt").to(self.DEVICE)
        with torch.no_grad():
            wav = self.tts_model(**inputs).waveform
        self._logger.debug('text_to_speech wav shape: %s', tuple(wav.shape))
        return {'wav': wav, 'rate': self.tts_model.config.sampling_rate}

    def topic_voice_to_summary_voices(self, topic_voice, number_articles):
        """Run the full pipeline for one recording.

        Returns:
            A list of ``text_to_speech`` results, one per summarized article.
        """
        # speech_to_text returns a list of strings; the search needs one string.
        topic = ' '.join(self.speech_to_text(topic_voice)).strip()
        self._logger.info('topic: %s', topic)
        articles = self.get_articles(topic, number_articles)
        self._logger.info('fetched %d article(s)', len(articles))
        summaries = self.summarize_articles(articles, self.summ_model)
        self._logger.info('summaries: %s', summaries)
        voices_wav_rate = [self.text_to_speech(summary) for summary in summaries]
        return voices_wav_rate

    def run(self):
        """Build the Gradio Blocks interface and return it without launching."""
        intro_path = os.path.join(os.getcwd(), 'gradio intro.mp3')
        with gr.Blocks(
            title='أخبار مسموعة',
            analytics_enabled=True,
            # Blocks expects a theme instance (or name), not the theme class.
            theme=gr.themes.Glass(),
            # A CSS rule needs a selector, and the property is `direction`.
            css='.gradio-container {direction: rtl;}',
        ) as demo:
            gr.Markdown(
                "# أخبار مسموعة\n"
                "\n"
                "اذكر الموضوع الذي تريد البحث عنه وسوف نخبرك بملخصات الأخبار بشأنه.\n",
                rtl=True,
            )
            # Hidden auto-playing intro clip; omitted when the file is absent
            # from the working directory.
            gr.Audio(
                type='filepath',
                value=intro_path if os.path.isfile(intro_path) else None,
                visible=False,
                autoplay=True,
            )
            topic_voice = gr.Audio(type="numpy", sources='microphone', label='سجل موضوع للبحث')
            num_articles = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="عدد المقالات")
            output_audio = gr.Audio(streaming=True, autoplay=True, label='الملخصات')

            @topic_voice.stop_recording(inputs=[topic_voice, num_articles], outputs=output_audio)
            def get_summ_audio(topic_voice, num_articles):
                """Return every spoken summary as a single ``(rate, int16)`` clip."""
                if topic_voice is None:
                    # Nothing was recorded.
                    return None
                summ_voices = self.topic_voice_to_summary_voices(topic_voice, num_articles)
                # All summaries are joined; None is returned when there are none.
                return self._merge_voices(summ_voices, self.PCM_SCALE)

        return demo
|
|
# Instantiate the application (this loads every model), build its Gradio
# interface, then start serving it.
app = Webapp()
demo = app.run()
demo.launch()
|