import os

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
from duckduckgo_search import DDGS
from newspaper import Article
import scipy
import scipy.signal  # `import scipy` alone does not guarantee the submodule is loaded
from transformers import (
    MT5Tokenizer,
    AdamW,
    MT5ForConditionalGeneration,
    pipeline,
)
from transformers import VitsModel, AutoTokenizer
import IPython.display as ipd
import numpy as np
import gradio as gr


class Webapp:
    """Voice-driven Arabic news summarizer served through a Gradio UI.

    Pipeline: recorded topic -> text (Whisper) -> news search (DuckDuckGo)
    -> article download (newspaper) -> summary (mT5) -> speech (VITS).
    All models are loaded in ``__init__``.
    """

    # Article used when the news search raises.
    FALLBACK_URL = 'https://www.bbc.com/arabic/media-65576589'
    # Sampling rate the recording is resampled to before it is handed to
    # the Whisper processor.
    STT_SAMPLE_RATE = 16000
    # Multiplier applied to the TTS waveform before it is cast to int16.
    PCM_GAIN = 15000

    def __init__(self):
        """Load the speech-to-text, summarization and text-to-speech models."""
        self.DEVICE = 0 if torch.cuda.is_available() else "cpu"
        self.REF_MODEL = 'google/mt5-small'
        self.MODEL_NAME = 'Ahmedasd/arabic-summarization-hhh-100-batches'
        self.model_id = "openai/whisper-base"
        self.tts_model_id = "SeyedAli/Arabic-Speech-synthesis"

        # Text-to-speech.
        self.tts_model = VitsModel.from_pretrained(self.tts_model_id).to(self.DEVICE)
        self.tts_tokenizer = AutoTokenizer.from_pretrained(self.tts_model_id)

        # Summarization: tokenizer from the reference model, weights from the
        # fine-tuned checkpoint.
        self.summ_tokenizer = MT5Tokenizer.from_pretrained(self.REF_MODEL)
        self.summ_model = MT5ForConditionalGeneration.from_pretrained(self.MODEL_NAME).to(self.DEVICE)

        # Kept for backward compatibility; not read anywhere in this class.
        self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        # Speech-to-text.
        self.stt_model = WhisperForConditionalGeneration.from_pretrained(self.model_id)
        self.stt_model.to(self.DEVICE)
        self.processor = WhisperProcessor.from_pretrained(self.model_id)
        self.forced_decoder_ids = self.processor.get_decoder_prompt_ids(language="arabic", task="transcribe")

    def speech_to_text(self, input):
        """Transcribe a recording to Arabic text.

        Args:
            input: ``(sample_rate, samples)`` pair as produced by a
                ``gr.Audio(type="numpy")`` component, or ``None`` when nothing
                was recorded.

        Returns:
            List of decoded transcriptions (the output of
            ``processor.batch_decode``); an empty list when there is no audio.
        """
        if input is None:
            return []

        sample_rate, samples = input
        samples = np.asarray(samples)
        if samples.size == 0:
            return []
        print('gradio audio: rate =', sample_rate, 'shape =', samples.shape,
              'dtype =', samples.dtype)

        # Down-mix (samples, channels) recordings to mono so the processor
        # does not mistake the channel axis for a batch.
        if samples.ndim > 1:
            samples = samples.mean(axis=1)

        # Scale integer PCM into [-1, 1] floats before feature extraction.
        if np.issubdtype(samples.dtype, np.integer):
            samples = samples / np.iinfo(samples.dtype).max

        # Resample from the rate the recording was actually captured at; the
        # previous code assumed 48 kHz regardless of what Gradio reported.
        target_rate = self.STT_SAMPLE_RATE
        if sample_rate != target_rate:
            new_length = int(len(samples) * target_rate / sample_rate)
            if new_length == 0:
                return []
            samples = scipy.signal.resample(samples, new_length)
        samples = np.asarray(samples, dtype=np.float32)

        input_features = self.processor(
            samples,
            sampling_rate=target_rate,
            return_tensors="pt",
        ).input_features.to(self.DEVICE)
        predicted_ids = self.stt_model.generate(
            input_features,
            forced_decoder_ids=self.forced_decoder_ids,
        )
        transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
        return transcription

    def get_articles(self, query, num):
        """Search news for ``query`` and return the text of the found articles.

        Args:
            query: Search keywords, either a string or a list of strings
                (e.g. the list returned by ``speech_to_text``), which is
                joined with spaces.
            num: Maximum number of search results to request.

        Returns:
            List of article texts with newlines removed. Articles that fail
            to download/parse or have no text are skipped. If the search
            itself raises, the single ``FALLBACK_URL`` article is used.
        """
        # The search API takes a keyword string; passing the transcription
        # list straight through made the search fail on every request.
        if not isinstance(query, str):
            query = ' '.join(query)

        try:
            with DDGS(timeout=20) as ddgs:
                results = ddgs.news(query, max_results=int(num))
                # Consume the results while the session is still open.
                urls = [r['url'] for r in results]
            print('successful connection!')
        except Exception as error:
            # Deliberate best-effort: keep the demo usable when search is
            # down, but say why the fallback was taken.
            print('news search failed, using fallback article: ', error)
            urls = [self.FALLBACK_URL]

        articles = []
        for url in urls:
            try:
                article = Article(url)
                article.download()
                article.parse()
            except Exception as error:
                # One unreachable/unparsable page must not abort the others.
                print('skipping article ', url, ': ', error)
                continue
            text = article.text.replace('\n', '')
            if text:
                articles.append(text)
        return articles

    def summarize(self, text, model):
        """Summarize ``text`` with the summarization model.

        Args:
            text: Article text; it is truncated/padded to 512 tokens.
            model: Unused; generation always runs on ``self.summ_model``.
                Kept so existing callers keep working.

        Returns:
            The decoded summary as a single string.
        """
        text_encoding = self.summ_tokenizer(
            text,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt',
        )
        generated_ids = self.summ_model.generate(
            input_ids=text_encoding['input_ids'].to(self.DEVICE),
            attention_mask=text_encoding['attention_mask'].to(self.DEVICE),
            max_length=128,
            repetition_penalty=2.5,
        )
        preds = [
            self.summ_tokenizer.decode(
                gen_id,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            for gen_id in generated_ids
        ]
        return "".join(preds)

    def summarize_articles(self, articles: list, model):
        """Summarize every article in ``articles``.

        Args:
            articles: Iterable of article texts.
            model: Forwarded to ``summarize`` (which ignores it).

        Returns:
            List of summaries, in the same order as ``articles``.
        """
        return [self.summarize(article, model) for article in articles]

    def text_to_speech(self, text):
        """Synthesize speech for ``text``.

        Returns:
            Dict with ``'wav'`` (the waveform tensor produced by the TTS
            model) and ``'rate'`` (the model's configured sampling rate).
        """
        inputs = self.tts_tokenizer(text, return_tensors="pt").to(self.DEVICE)
        print('text_to_speech text: ', text)
        with torch.no_grad():
            wav = self.tts_model(**inputs).waveform
        return {'wav': wav, 'rate': self.tts_model.config.sampling_rate}

    def topic_voice_to_summary_voices(self, topic_voice, number_articles):
        """Run the full pipeline from a recorded topic to spoken summaries.

        Args:
            topic_voice: ``(sample_rate, samples)`` recording of the topic.
            number_articles: Maximum number of articles to fetch.

        Returns:
            List of ``text_to_speech`` results, one per summarized article;
            empty when nothing was transcribed or no article was retrieved.
        """
        topic = self.speech_to_text(topic_voice)
        print('topic: ', topic)
        if not topic:
            return []
        articles = self.get_articles(topic, number_articles)
        print('articles: ', articles)
        summaries = self.summarize_articles(articles, self.summ_model)
        print('summaries: ', summaries)
        return [self.text_to_speech(summary) for summary in summaries]

    @classmethod
    def _to_pcm16(cls, voice):
        """Convert one ``text_to_speech`` result to a ``(rate, int16 array)`` pair."""
        scaled = voice['wav'].cpu().numpy() * cls.PCM_GAIN
        # Clip before the cast so an out-of-range sample cannot wrap around.
        limits = np.iinfo(np.int16)
        pcm = np.clip(scaled, limits.min, limits.max).astype(np.int16)
        return voice['rate'], np.squeeze(pcm)

    def run(self):
        """Build the Gradio interface and return it (not yet launched)."""
        # NOTE(review): 'dir: rtl;' has no selector and `dir` is not a CSS
        # property, so this stylesheet most likely has no effect. Left
        # untouched because fixing it would change the rendered layout.
        with gr.Blocks(title='أخبار مسموعة',
                       analytics_enabled=True,
                       # Pass a theme instance; the bare class is not a
                       # valid value for `theme`.
                       theme=gr.themes.Glass(),
                       css='dir: rtl;') as demo:
            gr.Markdown(
                """
                # أخبار مسموعة
                اذكر الموضوع الذي تريد البحث عنه وسوف نخبرك بملخصات الأخبار بشأنه.
                """,
                rtl=True)
            # Hidden, auto-played introduction clip expected in the working
            # directory.
            intro_voice = gr.Audio(type='filepath',
                                   value=os.path.join(os.getcwd(), 'gradio intro.mp3'),
                                   visible=False,
                                   autoplay=True)
            topic_voice = gr.Audio(type="numpy", sources='microphone',
                                   label='سجل موضوع للبحث')
            num_articles = gr.Slider(minimum=1, maximum=10, value=1, step=1,
                                     label="عدد المقالات")
            output_audio = gr.Audio(streaming=True, autoplay=True,
                                    label='الملخصات')

            # Events
            # Generate summaries once the user stops recording.
            @topic_voice.stop_recording(inputs=[topic_voice, num_articles],
                                        outputs=output_audio)
            def get_summ_audio(topic_voice, num_articles):
                summ_voices = self.topic_voice_to_summary_voices(topic_voice, num_articles)
                if not summ_voices:
                    # Nothing transcribed or no article retrieved: leave the
                    # player empty instead of raising IndexError.
                    return None
                # TODO(review): only the first summary is played, as before,
                # even when several articles were requested.
                rate, pcm = self._to_pcm16(summ_voices[0])
                print('summary audio: rate =', rate, 'len =', pcm.size)
                return rate, pcm

        return demo


app = Webapp()
app.run().launch()