File size: 6,047 Bytes
83fcd99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
from duckduckgo_search import DDGS
from newspaper import Article
import scipy
from transformers import (
    MT5Tokenizer,
    AdamW,
    MT5ForConditionalGeneration,
    pipeline
)
from transformers import VitsModel, AutoTokenizer
import IPython.display as ipd
import torch
import numpy as np
import gradio as gr
import os

class Webapp:
    """Voice-driven Arabic news summarizer exposed through a Gradio UI.

    Pipeline: recorded topic -> Whisper transcription -> DuckDuckGo news
    search -> article download (newspaper) -> mT5 summary -> VITS speech.
    All models are loaded in ``__init__`` and moved to ``self.DEVICE``.
    """

    # Sample rate of the audio handed to the Whisper processor.
    STT_SAMPLE_RATE = 16000
    # Article used when the news search itself raises.
    FALLBACK_URL = 'https://www.bbc.com/arabic/media-65576589'
    # Gain applied to the TTS waveform before it is cast to int16 PCM.
    PCM_GAIN = 15000

    def __init__(self):
        """Load the TTS, summarization and STT models and their tokenizers."""
        self.DEVICE = 0 if torch.cuda.is_available() else "cpu"
        self.REF_MODEL = 'google/mt5-small'
        self.MODEL_NAME = 'Ahmedasd/arabic-summarization-hhh-100-batches'
        self.model_id = "openai/whisper-base"
        self.tts_model_id = "SeyedAli/Arabic-Speech-synthesis"

        # Text-to-speech.
        self.tts_model = VitsModel.from_pretrained(self.tts_model_id).to(self.DEVICE)
        self.tts_tokenizer = AutoTokenizer.from_pretrained(self.tts_model_id)

        # Summarization: tokenizer comes from the reference mT5 checkpoint,
        # weights from the fine-tuned checkpoint.
        self.summ_tokenizer = MT5Tokenizer.from_pretrained(self.REF_MODEL)
        self.summ_model = MT5ForConditionalGeneration.from_pretrained(self.MODEL_NAME).to(self.DEVICE)

        # Not consumed inside this class; kept because external code may read it.
        self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        # Speech-to-text.
        self.stt_model = WhisperForConditionalGeneration.from_pretrained(self.model_id)
        self.stt_model.to(self.DEVICE)
        self.processor = WhisperProcessor.from_pretrained(self.model_id)
        self.forced_decoder_ids = self.processor.get_decoder_prompt_ids(language="arabic", task="transcribe")

    @staticmethod
    def _prepare_audio(audio, target_rate=16000):
        """Convert a ``(sample_rate, samples)`` pair to mono float32 at *target_rate*.

        Integer PCM is scaled by the dtype's full-scale value, multi-channel
        input (assumed ``(samples, channels)``, the layout Gradio documents
        for numpy audio) is averaged to mono, and the signal is resampled
        using the sample rate that came with the recording.

        Returns an empty float32 array when *audio* is ``None`` or has no
        samples.
        """
        if audio is None:
            return np.zeros(0, dtype=np.float32)

        sample_rate, samples = audio
        samples = np.asarray(samples)

        if np.issubdtype(samples.dtype, np.integer):
            limits = np.iinfo(samples.dtype)
            full_scale = float(max(abs(limits.min), limits.max))
            samples = samples.astype(np.float32) / full_scale
        else:
            samples = samples.astype(np.float32)

        if samples.ndim > 1:
            samples = samples.mean(axis=1)

        if samples.size and sample_rate != target_rate:
            new_length = int(len(samples) * target_rate / sample_rate)
            if new_length == 0:
                return np.zeros(0, dtype=np.float32)
            # Imported here so the submodule is loaded explicitly; a bare
            # ``import scipy`` does not expose ``scipy.signal`` on every version.
            from scipy.signal import resample
            samples = resample(samples, new_length)

        return np.asarray(samples, dtype=np.float32)

    @staticmethod
    def _to_int16_pcm(waveform, gain=15000):
        """Scale a float waveform by *gain* and return squeezed int16 samples.

        Accepts a torch tensor or anything ``np.asarray`` understands. Values
        are clipped to the int16 range before the cast so they cannot wrap.
        """
        if hasattr(waveform, 'detach'):
            waveform = waveform.detach().cpu().numpy()
        scaled = np.asarray(waveform, dtype=np.float32) * gain
        limits = np.iinfo(np.int16)
        return np.squeeze(np.clip(scaled, limits.min, limits.max).astype(np.int16))

    def speech_to_text(self, input):
        """Transcribe a recording to Arabic text with Whisper.

        Args:
            input: ``(sample_rate, samples)`` tuple as delivered by
                ``gr.Audio(type="numpy")``. The parameter name is kept for
                backward compatibility even though it shadows the builtin.

        Returns:
            The list produced by ``processor.batch_decode``; an empty list
            when the recording contains no samples.
        """
        samples = self._prepare_audio(input, self.STT_SAMPLE_RATE)
        if samples.size == 0:
            return []

        input_features = self.processor(
            samples,
            sampling_rate=self.STT_SAMPLE_RATE,
            return_tensors="pt",
        ).input_features.to(self.DEVICE)
        predicted_ids = self.stt_model.generate(
            input_features, forced_decoder_ids=self.forced_decoder_ids
        )
        return self.processor.batch_decode(predicted_ids, skip_special_tokens=True)

    def get_articles(self, query, num):
        """Search news for *query* and return the text of up to *num* articles.

        Args:
            query: Search string, or a list of strings (joined with spaces).
            num: Maximum number of search results to request.

        Returns:
            List of article texts with whitespace collapsed to single spaces.
            Articles that fail to download/parse or have no text are skipped.
            If the search raises, the single ``FALLBACK_URL`` article is used.
        """
        if not isinstance(query, str):
            query = ' '.join(query)

        try:
            with DDGS(timeout=20) as ddgs:
                results = ddgs.news(query, max_results=int(num))
                urls = [r['url'] for r in results]
            print('successful connection!')
        except Exception as error:
            # Deliberate best-effort: any search failure falls back to a
            # fixed article, but the reason is reported instead of hidden.
            print('news search failed, using fallback article: ', error)
            urls = [self.FALLBACK_URL]

        articles = []
        for url in urls:
            try:
                article = Article(url)
                article.download()
                article.parse()
            except Exception as error:
                # One unreachable page should not abort the whole request.
                print('skipping article ', url, ': ', error)
                continue
            # Collapse whitespace rather than deleting newlines, which would
            # glue the words on either side of a line break together.
            text = ' '.join(article.text.split())
            if text:
                articles.append(text)
        return articles

    def summarize(self, text, model=None):
        """Summarize *text* with the mT5 summarizer.

        Args:
            text: Article text; truncated/padded to 512 tokens.
            model: Model whose ``generate`` is used; defaults to
                ``self.summ_model`` when ``None``.

        Returns:
            The decoded summary as a single string.
        """
        if model is None:
            model = self.summ_model

        text_encoding = self.summ_tokenizer(
            text,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt',
        )
        generated_ids = model.generate(
            input_ids=text_encoding['input_ids'].to(self.DEVICE),
            attention_mask=text_encoding['attention_mask'].to(self.DEVICE),
            max_length=128,
            repetition_penalty=2.5,
        )
        preds = [
            self.summ_tokenizer.decode(
                gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            for gen_id in generated_ids
        ]
        return "".join(preds)

    def summarize_articles(self, articles: list, model):
        """Return one summary per entry of *articles*, in order."""
        return [self.summarize(article, model) for article in articles]

    def text_to_speech(self, text):
        """Synthesize *text* with the VITS model.

        Returns:
            Dict with ``'wav'`` (the model output's ``waveform``) and
            ``'rate'`` (the model config's ``sampling_rate``).
        """
        inputs = self.tts_tokenizer(text, return_tensors="pt").to(self.DEVICE)
        print('text_to_speech text: ', text)
        with torch.no_grad():
            wav = self.tts_model(**inputs).waveform
        return {'wav': wav, 'rate': self.tts_model.config.sampling_rate}

    def topic_voice_to_summary_voices(self, topic_voice, number_articles):
        """Run the full pipeline from a recorded topic to spoken summaries.

        Args:
            topic_voice: ``(sample_rate, samples)`` recording of the topic.
            number_articles: Maximum number of articles to fetch.

        Returns:
            List of ``text_to_speech`` results, one per non-empty summary;
            empty when nothing was transcribed.
        """
        # speech_to_text returns a list; the search needs a plain string.
        topic = ' '.join(self.speech_to_text(topic_voice)).strip()
        print('topic: ', topic)
        if not topic:
            return []

        articles = self.get_articles(topic, number_articles)
        print('articles: ', articles)
        summaries = self.summarize_articles(articles, self.summ_model)
        print('summaries: ', summaries)
        return [self.text_to_speech(summary) for summary in summaries if summary.strip()]

    def run(self):
        """Build the Gradio Blocks UI and return it (not yet launched)."""
        # Pass a theme instance, as the Gradio theming guide shows.
        with gr.Blocks(title = 'أخبار مسموعة', analytics_enabled=True, theme = gr.themes.Glass(), css = 'dir: rtl;') as demo:
            gr.Markdown(
            """
      # أخبار مسموعة
      اذكر الموضوع الذي تريد البحث عنه وسوف نخبرك بملخصات الأخبار بشأنه.
      """, rtl = True)
            intro_voice = gr.Audio(type='filepath', value = os.getcwd() + '/gradio intro.mp3', visible = False, autoplay = True)
            topic_voice = gr.Audio(type="numpy", sources = 'microphone', label ='سجل موضوع للبحث')
            num_articles = gr.Slider(minimum=1, maximum=10, value=1, step = 1, label = "عدد المقالات")
            output_audio = gr.Audio(streaming = True, autoplay = True, label = 'الملخصات')

            # Events
            # Generate summaries once the user stops recording.
            @topic_voice.stop_recording(inputs = [topic_voice, num_articles], outputs = output_audio)
            def get_summ_audio(recording, article_count):
                if recording is None:
                    return None
                summ_voices = self.topic_voice_to_summary_voices(recording, article_count)
                if not summ_voices:
                    # Nothing transcribed, found or summarized: leave output empty.
                    return None
                # Only the first summary is played.
                first = summ_voices[0]
                return (first['rate'], self._to_int16_pcm(first['wav'], self.PCM_GAIN))
        return demo

# Script entry point: constructing Webapp loads every model, run() builds the
# Gradio Blocks UI, and launch() starts serving it.
app = Webapp()
demo = app.run()
demo.launch()