Ahmedasd's picture
first commit
83fcd99
raw
history blame
6.05 kB
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
from duckduckgo_search import DDGS
from newspaper import Article
import scipy
from transformers import (
MT5Tokenizer,
AdamW,
MT5ForConditionalGeneration,
pipeline
)
from transformers import VitsModel, AutoTokenizer
import IPython.display as ipd
import torch
import numpy as np
import gradio as gr
import os
class Webapp:
    """Voice-driven Arabic news summariser served through a Gradio UI.

    Pipeline: a recorded topic is transcribed with Whisper, news articles about
    it are looked up through DuckDuckGo and downloaded with newspaper, each
    article is summarised with a fine-tuned mT5 checkpoint, and the summaries
    are synthesised to speech with a VITS model.
    """

    # Sampling rate the recorded audio is resampled to before transcription.
    STT_SAMPLE_RATE = 16000
    # Article used when the news search itself fails.
    FALLBACK_URL = 'https://www.bbc.com/arabic/media-65576589'
    # Gain applied to the synthesised waveform before it is cast to int16.
    PCM_GAIN = 15000

    def __init__(self):
        """Load every model, tokenizer and processor used by the pipeline."""
        # `0` selects the first CUDA device; both values are accepted by `.to()`.
        self.DEVICE = 0 if torch.cuda.is_available() else "cpu"
        self.REF_MODEL = 'google/mt5-small'
        self.MODEL_NAME = 'Ahmedasd/arabic-summarization-hhh-100-batches'
        self.model_id = "openai/whisper-base"
        self.tts_model_id = "SeyedAli/Arabic-Speech-synthesis"

        # Text-to-speech.
        self.tts_model = VitsModel.from_pretrained(self.tts_model_id).to(self.DEVICE)
        self.tts_tokenizer = AutoTokenizer.from_pretrained(self.tts_model_id)

        # Summarisation: tokenizer from the reference model, weights from the
        # fine-tuned checkpoint.
        self.summ_tokenizer = MT5Tokenizer.from_pretrained(self.REF_MODEL)
        self.summ_model = MT5ForConditionalGeneration.from_pretrained(self.MODEL_NAME).to(self.DEVICE)

        # Kept as an attribute for compatibility; no model in this class is
        # loaded with it.
        self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        # Speech-to-text, forced to transcribe Arabic.
        self.stt_model = WhisperForConditionalGeneration.from_pretrained(self.model_id)
        self.stt_model.to(self.DEVICE)
        self.processor = WhisperProcessor.from_pretrained(self.model_id)
        self.forced_decoder_ids = self.processor.get_decoder_prompt_ids(language="arabic", task="transcribe")

    @staticmethod
    def _prepare_waveform(audio, target_rate=16000):
        """Turn a `(sample_rate, samples)` pair into mono float32 at `target_rate`.

        Signed-integer PCM is scaled by the integer type's full range, a 2-D
        `(samples, channels)` array is averaged down to one channel, and the
        result is resampled only when the source rate differs from
        `target_rate`.
        """
        # Imported here so the `signal` submodule is guaranteed to be loaded;
        # a bare `import scipy` does not promise that on every SciPy version.
        from scipy import signal

        sample_rate, samples = audio
        samples = np.asarray(samples)
        waveform = samples.astype(np.float32)
        if np.issubdtype(samples.dtype, np.signedinteger):
            waveform /= float(np.iinfo(samples.dtype).max) + 1.0
        if waveform.ndim == 2:
            waveform = waveform.mean(axis=1)
        if sample_rate != target_rate and waveform.size:
            new_length = max(1, int(len(waveform) * target_rate / sample_rate))
            waveform = signal.resample(waveform, new_length)
        return np.asarray(waveform, dtype=np.float32)

    @staticmethod
    def _to_pcm16(wav, gain=15000):
        """Scale a waveform tensor by `gain` and return it as a squeezed int16 array."""
        return np.squeeze(np.array(wav.detach().cpu() * gain, dtype=np.int16))

    def speech_to_text(self, input):
        """Transcribe a Gradio numpy recording.

        Args:
            input: `(sample_rate, samples)` tuple as produced by
                `gr.Audio(type="numpy")`, or `None` when nothing was recorded.

        Returns:
            The list of decoded strings returned by `batch_decode`; an empty
            list when `input` is `None`.
        """
        print('gradio audio type: ', type(input))
        print('gradio audio: ', input)
        if input is None:
            return []

        # Use the rate reported by Gradio rather than assuming 48 kHz.
        audio_sr_16000 = self._prepare_waveform(input, self.STT_SAMPLE_RATE)
        print('input audio16000: ', audio_sr_16000)

        input_features = self.processor(
            audio_sr_16000,
            sampling_rate=self.STT_SAMPLE_RATE,
            return_tensors="pt",
        ).input_features.to(self.DEVICE)
        predicted_ids = self.stt_model.generate(input_features, forced_decoder_ids=self.forced_decoder_ids)
        return self.processor.batch_decode(predicted_ids, skip_special_tokens=True)

    def get_articles(self, query, num):
        """Search news for `query` and return the text of up to `num` articles.

        If the search raises, the error is printed and a single fallback
        article is used instead. Articles that fail to download or parse, or
        whose extracted text is empty, are skipped.
        """
        try:
            with DDGS(timeout=20) as ddgs:
                results = ddgs.news(query, max_results=int(num))
                urls = [r['url'] for r in results]
            print('successful connection!')
        except Exception as error:
            # Best-effort boundary: keep the app usable when the search fails,
            # but do not hide why it failed.
            print('news search failed, using fallback article: ', error)
            urls = [self.FALLBACK_URL]

        articles = []
        for url in urls:
            try:
                article = Article(url)
                article.download()
                article.parse()
            except Exception as error:
                # One unreachable page must not abort the whole request.
                print('could not fetch article: ', url, error)
                continue
            text = article.text.replace('\n', '')
            if text:
                articles.append(text)
        return articles

    def summarize(self, text, model):
        """Summarise `text` with `self.summ_model`.

        Args:
            text: Article text; tokenised with truncation/padding to 512 tokens.
            model: Accepted for interface compatibility and not used; generation
                always runs on `self.summ_model`.

        Returns:
            The decoded summaries concatenated into one string.
        """
        text_encoding = self.summ_tokenizer(
            text,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt',
        )
        generated_ids = self.summ_model.generate(
            input_ids=text_encoding['input_ids'].to(self.DEVICE),
            attention_mask=text_encoding['attention_mask'].to(self.DEVICE),
            max_length=128,
            repetition_penalty=2.5,
        )
        preds = [
            self.summ_tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for gen_id in generated_ids
        ]
        return "".join(preds)

    def summarize_articles(self, articles: list, model):
        """Return one summary per entry of `articles`, in the same order."""
        return [self.summarize(article, model) for article in articles]

    def text_to_speech(self, text):
        """Synthesise `text`.

        Returns:
            A dict with the model's output waveform under `'wav'` and the
            model's configured sampling rate under `'rate'`.
        """
        inputs = self.tts_tokenizer(text, return_tensors="pt").to(self.DEVICE)
        print('text_to_speech text: ', text)
        with torch.no_grad():
            wav = self.tts_model(**inputs).waveform
        print('text_to_speech wav: ', wav)
        return {'wav': wav, 'rate': self.tts_model.config.sampling_rate}

    def topic_voice_to_summary_voices(self, topic_voice, number_articles):
        """Run the full pipeline from a recorded topic to spoken summaries.

        Returns:
            A list of `text_to_speech` results, one per summarised article;
            empty when nothing was transcribed or no article was retrieved.
        """
        # `speech_to_text` returns a list of strings; the search needs one string.
        topic = ' '.join(self.speech_to_text(topic_voice)).strip()
        print('topic: ', topic)
        if not topic:
            return []

        articles = self.get_articles(topic, number_articles)
        print('articles: ', articles)
        summaries = self.summarize_articles(articles, self.summ_model)
        print('summaries: ', summaries)
        return [self.text_to_speech(summary) for summary in summaries]

    def run(self):
        """Build the Gradio interface and return it without launching."""
        # NOTE(review): 'dir: rtl;' has no selector, so it is probably not an
        # effective CSS rule; kept unchanged pending confirmation of the intent.
        with gr.Blocks(
            title='أخبار مسموعة',
            analytics_enabled=True,
            # Blocks expects a theme instance (or a name), not the theme class.
            theme=gr.themes.Glass(),
            css='dir: rtl;',
        ) as demo:
            gr.Markdown(
                """
# أخبار مسموعة
اذكر الموضوع الذي تريد البحث عنه وسوف نخبرك بملخصات الأخبار بشأنه.
""", rtl=True)
            intro_voice = gr.Audio(type='filepath', value=os.getcwd() + '/gradio intro.mp3', visible=False, autoplay=True)
            topic_voice = gr.Audio(type="numpy", sources='microphone', label='سجل موضوع للبحث')
            num_articles = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="عدد المقالات")
            output_audio = gr.Audio(streaming=True, autoplay=True, label='الملخصات')

            # Events
            # Generate the summaries once the user stops recording.
            @topic_voice.stop_recording(inputs=[topic_voice, num_articles], outputs=output_audio)
            def get_summ_audio(topic_voice, num_articles):
                summ_voices = self.topic_voice_to_summary_voices(topic_voice, num_articles)
                print('summ voices: ', summ_voices)
                if not summ_voices:
                    # Nothing transcribed or retrieved: leave the output empty
                    # instead of failing on the index below.
                    return None

                # Only the first summary is played.
                first = summ_voices[0]
                pcm = self._to_pcm16(first['wav'], self.PCM_GAIN)
                print('wav: ')
                print('max: ', pcm.max())
                print('min: ', pcm.min())
                print('len: ', len(pcm))
                return (first['rate'], pcm)

        return demo
# Script entry: constructing Webapp loads every model, run() assembles the
# Gradio interface, and launch() starts serving it.
app = Webapp()
interface = app.run()
interface.launch()