File size: 6,487 Bytes
83fcd99 f62233e 83fcd99 f62233e 83fcd99 f62233e 83fcd99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
from duckduckgo_search import DDGS
from newspaper import Article
import scipy
import random
from transformers import (
MT5Tokenizer,
AdamW,
MT5ForConditionalGeneration,
pipeline
)
from transformers import VitsModel, AutoTokenizer
import IPython.display as ipd
import torch
import numpy as np
import gradio as gr
import os
class Webapp:
    """Voice-driven Arabic news summarizer served as a Gradio app.

    Pipeline (see ``topic_voice_to_summary_voices``): microphone recording ->
    Whisper transcription (Arabic) -> DuckDuckGo news search -> article text
    via ``newspaper`` -> mT5 summary -> VITS speech synthesis.  All models are
    loaded once in ``__init__``.
    """

    # Sample rate (Hz) the Whisper feature extractor is called with.
    STT_SAMPLE_RATE = 16000
    # Gain applied to the float TTS waveform before conversion to int16 PCM.
    OUTPUT_GAIN = 15000
    # Silence (seconds) inserted between consecutive spoken summaries.
    SUMMARY_GAP_SECONDS = 0.5
    # Articles used when the news search fails; one is picked at random.
    FALLBACK_URLS = (
        'https://www.bbc.com/arabic/media-65576589',
        'https://www.bbc.com/arabic/articles/czr8dk93231o',
        'https://www.bbc.com/arabic/articles/c0jyd2yweplo',
        'https://www.bbc.com/arabic/articles/cnd8wwdyyzko',
        'https://www.bbc.com/arabic/articles/c3gyxymp0z1o',
        'https://www.bbc.com/arabic/articles/c3g28kl8zj4o',
    )

    def __init__(self):
        """Load the TTS, summarization and speech-to-text models and their tokenizers."""
        # Device index 0 selects the first CUDA device; otherwise run on CPU.
        self.DEVICE = 0 if torch.cuda.is_available() else "cpu"
        self.REF_MODEL = 'google/mt5-small'
        self.MODEL_NAME = 'Ahmedasd/arabic-summarization-hhh-100-batches'
        self.model_id = "openai/whisper-tiny"
        self.tts_model_id = "SeyedAli/Arabic-Speech-synthesis"
        self.tts_model = VitsModel.from_pretrained(self.tts_model_id).to(self.DEVICE)
        self.tts_tokenizer = AutoTokenizer.from_pretrained(self.tts_model_id)
        # Tokenizer comes from the base mT5 checkpoint, weights from MODEL_NAME.
        self.summ_tokenizer = MT5Tokenizer.from_pretrained(self.REF_MODEL)
        self.summ_model = MT5ForConditionalGeneration.from_pretrained(self.MODEL_NAME).to(self.DEVICE)
        # NOTE: not used by this class; the models are loaded in their default dtype.
        self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.stt_model = WhisperForConditionalGeneration.from_pretrained(self.model_id)
        self.stt_model.to(self.DEVICE)
        self.processor = WhisperProcessor.from_pretrained(self.model_id)
        # Make Whisper transcribe (not translate) Arabic speech.
        self.forced_decoder_ids = self.processor.get_decoder_prompt_ids(language="arabic", task="transcribe")

    @staticmethod
    def _prepare_stt_audio(samples, sample_rate, target_rate=STT_SAMPLE_RATE):
        """Convert raw recorded samples to mono float32 audio at ``target_rate`` Hz.

        Args:
            samples: array of audio samples; integer PCM or float.  Multi-channel
                input is assumed to be laid out as (samples, channels), as
                returned by ``gr.Audio(type="numpy")``.
            sample_rate: sample rate of ``samples`` in Hz.
            target_rate: desired output sample rate in Hz.

        Returns:
            numpy.ndarray: 1-D float audio at ``target_rate``.
        """
        audio = np.asarray(samples)
        if np.issubdtype(audio.dtype, np.integer):
            # Integer PCM -> floats in [-1, 1]; Whisper's feature extractor is
            # designed for float waveforms in that range.
            audio = audio.astype(np.float32) / np.iinfo(audio.dtype).max
        else:
            audio = audio.astype(np.float32)
        if audio.ndim > 1:
            # Down-mix to mono by averaging the channels.
            audio = audio.mean(axis=1)
        if sample_rate != target_rate:
            new_length = int(len(audio) * target_rate / sample_rate)
            audio = scipy.signal.resample(audio, new_length)
        return audio

    @staticmethod
    def _to_int16_audio(waveform, gain=OUTPUT_GAIN):
        """Scale a float waveform by ``gain`` and return it as 1-D int16 PCM.

        ``waveform`` may be a torch tensor (on any device) or an array-like.
        Samples outside the int16 range saturate instead of wrapping around.
        """
        if isinstance(waveform, torch.Tensor):
            waveform = waveform.detach().cpu().numpy()
        scaled = np.asarray(waveform, dtype=np.float32).reshape(-1) * gain
        info = np.iinfo(np.int16)
        return np.clip(scaled, info.min, info.max).astype(np.int16)

    @classmethod
    def _join_voices(cls, voices, gain=OUTPUT_GAIN, gap_seconds=SUMMARY_GAP_SECONDS):
        """Concatenate TTS results into a single ``(rate, int16 samples)`` clip.

        Args:
            voices: non-empty list of ``{'wav': waveform, 'rate': int}`` dicts,
                as produced by ``text_to_speech``.
            gain: scale applied before the int16 conversion.
            gap_seconds: length of the silence inserted between clips.

        Returns:
            tuple: ``(rate, samples)``, using the first clip's rate.  Every clip
            comes from the same TTS model, so they share one sampling rate.
        """
        rate = voices[0]['rate']
        gap = np.zeros(int(rate * gap_seconds), dtype=np.int16)
        pieces = []
        for index, voice in enumerate(voices):
            if index:
                pieces.append(gap)
            pieces.append(cls._to_int16_audio(voice['wav'], gain))
        return rate, np.concatenate(pieces)

    def speech_to_text(self, input):
        """Transcribe a microphone recording to Arabic text with Whisper.

        Args:
            input: ``(sample_rate, samples)`` tuple as delivered by
                ``gr.Audio(type="numpy")``.  The name shadows the builtin but
                is kept for caller compatibility.

        Returns:
            list[str]: transcriptions from ``batch_decode`` (one entry for a
            single recording).
        """
        print('gradio audio type: ', type(input))
        print('gradio audio: ', input)
        # Use the recording's actual sample rate rather than assuming 48 kHz.
        sample_rate, samples = input
        audio_16k = self._prepare_stt_audio(samples, sample_rate, self.STT_SAMPLE_RATE)
        print('input audio16000: ', audio_16k)
        input_features = self.processor(
            audio_16k, sampling_rate=self.STT_SAMPLE_RATE, return_tensors="pt"
        ).input_features.to(self.DEVICE)
        predicted_ids = self.stt_model.generate(input_features, forced_decoder_ids=self.forced_decoder_ids)
        return self.processor.batch_decode(predicted_ids, skip_special_tokens=True)

    def get_articles(self, query, num):
        """Search DuckDuckGo News for ``query`` and return the article texts.

        Args:
            query: search string.
            num: maximum number of results.  It is coerced to ``int`` because
                the Gradio slider value may arrive as a float.

        Returns:
            list[str]: whitespace-normalized text of each article that was
            downloaded and parsed successfully.  If the search itself raises,
            one random URL from ``FALLBACK_URLS`` is used instead.
        """
        try:
            with DDGS(timeout=20) as ddgs:
                results = ddgs.news(query, max_results=int(num))
                urls = [r['url'] for r in results]
            print('successful connection!')
        except Exception as error:
            # Best effort: any search failure falls back to a canned article.
            print('news search failed, using a fallback article: ', error)
            urls = [random.choice(self.FALLBACK_URLS)]
        articles = []
        for url in urls:
            article = Article(url)
            try:
                article.download()
                article.parse()
            except Exception as error:
                # One unreachable or unparsable page should not abort the request.
                print('skipping article ', url, ': ', error)
                continue
            # Join on spaces (not ''), so words from adjacent paragraphs are not
            # glued together; this also collapses runs of whitespace.
            articles.append(' '.join(article.text.split()))
        return articles

    def summarize(self, text, model=None):
        """Summarize ``text`` with an mT5 seq2seq model.

        Args:
            text: article text.  The tokenized input is padded/truncated to 512 tokens.
            model: model used for generation; defaults to ``self.summ_model``.
                It must live on ``self.DEVICE``.

        Returns:
            str: the decoded summary (generation capped at ``max_length=128``).
        """
        model = self.summ_model if model is None else model
        text_encoding = self.summ_tokenizer(
            text,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )
        generated_ids = model.generate(
            input_ids=text_encoding['input_ids'].to(self.DEVICE),
            attention_mask=text_encoding['attention_mask'].to(self.DEVICE),
            max_length=128,
            repetition_penalty=2.5,
        )
        preds = [
            self.summ_tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for gen_id in generated_ids
        ]
        return "".join(preds)

    def summarize_articles(self, articles: "list[str]", model):
        """Return one summary per article text, in input order, generated with ``model``."""
        return [self.summarize(article, model) for article in articles]

    def text_to_speech(self, text):
        """Synthesize speech for ``text`` with the VITS model.

        Returns:
            dict: ``{'wav': waveform tensor from the model, 'rate': model sampling rate}``.
        """
        inputs = self.tts_tokenizer(text, return_tensors="pt").to(self.DEVICE)
        print('text_to_speech text: ', text)
        with torch.no_grad():
            wav = self.tts_model(**inputs).waveform
        print('text_to_speech wav: ', wav)
        return {'wav': wav, 'rate': self.tts_model.config.sampling_rate}

    def topic_voice_to_summary_voices(self, topic_voice, number_articles):
        """Run the full pipeline: spoken topic -> list of spoken-summary dicts."""
        # speech_to_text returns a list of transcriptions; the search needs one string.
        topic = ' '.join(self.speech_to_text(topic_voice)).strip()
        print('topic: ', topic)
        articles = self.get_articles(topic, number_articles)
        print('articles: ', articles)
        summaries = self.summarize_articles(articles, self.summ_model)
        print('summaries: ', summaries)
        return [self.text_to_speech(summary) for summary in summaries]

    def run(self):
        """Build and return the Gradio Blocks UI; call ``.launch()`` on the result."""
        with gr.Blocks(
            title='أخبار مسموعة',
            analytics_enabled=True,
            # Blocks expects a theme *instance*.  Passing the class makes Gradio
            # warn and fall back to the default theme.
            theme=gr.themes.Glass(),
            # NOTE(review): 'dir: rtl;' has no selector, and `dir` is not a CSS
            # property, so this stylesheet has no effect.  Something like
            # '.gradio-container {direction: rtl;}' would apply an RTL layout.
            css='dir: rtl;',
        ) as demo:
            gr.Markdown(
"""
# أخبار مسموعة
اذكر الموضوع الذي تريد البحث عنه وسوف نخبرك بملخصات الأخبار بشأنه.
""", rtl=True)
            # Hidden intro clip with autoplay enabled.
            intro_voice = gr.Audio(
                type='filepath',
                value=os.path.join(os.getcwd(), 'gradio intro.mp3'),
                visible=False,
                autoplay=True,
            )
            topic_voice = gr.Audio(type="numpy", sources='microphone', label='سجل موضوع للبحث')
            num_articles = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="عدد المقالات")
            output_audio = gr.Audio(streaming=True, autoplay=True, label='الملخصات')

            # Events: generate the summaries when the user stops recording.
            @topic_voice.stop_recording(inputs=[topic_voice, num_articles], outputs=output_audio)
            def get_summ_audio(topic_voice, num_articles):
                """Run the pipeline and return all summaries as one (rate, int16 samples) clip."""
                summ_voices = self.topic_voice_to_summary_voices(topic_voice, num_articles)
                print('summ voices: ', summ_voices)
                if not summ_voices:
                    # No article could be fetched, so there is nothing to play.
                    raise gr.Error('لم يتم العثور على أي مقالات.')
                # Play every requested summary, not just the first one.
                return self._join_voices(summ_voices, self.OUTPUT_GAIN, self.SUMMARY_GAP_SECONDS)
        return demo
# Only load the models and start the server when run as a script
# (e.g. `python app.py`), not when this module is imported.
if __name__ == "__main__":
    app = Webapp()
    app.run().launch()
|