veetla commited on
Commit
54d9d69
1 Parent(s): a13da65

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +139 -0
utils.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Some utility functions for the app."""
2
+ from base64 import b64encode
3
+ from io import BytesIO
4
+
5
+ from gtts import gTTS
6
+ from mtranslate import translate
7
+ from speech_recognition import AudioFile, Recognizer
8
+ from transformers import (BlenderbotSmallForConditionalGeneration,
9
+ BlenderbotSmallTokenizer)
10
+
11
+
12
+ def stt(audio: object, language: str) -> str:
13
+ """Converts speech to text.
14
+ Args:
15
+ audio: record of user speech
16
+ Returns:
17
+ text (str): recognized speech of user
18
+ """
19
+
20
+ # Create a Recognizer object
21
+ r = Recognizer()
22
+ # Open the audio file
23
+ with AudioFile(audio) as source:
24
+ # Listen for the data (load audio to memory)
25
+ audio_data = r.record(source)
26
+ # Transcribe the audio using Google's speech-to-text API
27
+ text = r.recognize_google(audio_data, language=language)
28
+ return text
29
+
30
+
31
+ def to_en_translation(text: str, language: str) -> str:
32
+ """Translates text from specified language to English.
33
+ Args:
34
+ text (str): input text
35
+ language (str): desired language
36
+ Returns:
37
+ str: translated text
38
+ """
39
+ return translate(text, "en", language)
40
+
41
+
42
+ def from_en_translation(text: str, language: str) -> str:
43
+ """Translates text from english to specified language.
44
+ Args:
45
+ text (str): input text
46
+ language (str): desired language
47
+ Returns:
48
+ str: translated text
49
+ """
50
+ return translate(text, language, "en")
51
+
52
+
53
+ class TextGenerationPipeline:
54
+ """Pipeline for text generation of blenderbot model.
55
+ Returns:
56
+ str: generated text
57
+ """
58
+
59
+ # load tokenizer and the model
60
+ model_name = "facebook/blenderbot_small-90M"
61
+ tokenizer = BlenderbotSmallTokenizer.from_pretrained(model_name)
62
+ model = BlenderbotSmallForConditionalGeneration.from_pretrained(model_name)
63
+
64
+ def __init__(self, **kwargs):
65
+ """Specififying text generation parameters.
66
+ For example: max_length=100 which generates text shorter than
67
+ 100 tokens. Visit:
68
+ https://huggingface.co/docs/transformers/main_classes/text_generation
69
+ for more parameters
70
+ """
71
+ self.__dict__.update(kwargs)
72
+
73
+ def preprocess(self, text) -> str:
74
+ """Tokenizes input text.
75
+ Args:
76
+ text (str): user specified text
77
+ Returns:
78
+ torch.Tensor (obj): text representation as tensors
79
+ """
80
+ return self.tokenizer(text, return_tensors="pt")
81
+
82
+ def postprocess(self, outputs) -> str:
83
+ """Converts tensors into text.
84
+ Args:
85
+ outputs (torch.Tensor obj): model text generation output
86
+ Returns:
87
+ str: generated text
88
+ """
89
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
90
+
91
+ def __call__(self, text: str) -> str:
92
+ """Generates text from input text.
93
+ Args:
94
+ text (str): user specified text
95
+ Returns:
96
+ str: generated text
97
+ """
98
+ tokenized_text = self.preprocess(text)
99
+ output = self.model.generate(**tokenized_text, **self.__dict__)
100
+ return self.postprocess(output)
101
+
102
+
103
+ def tts(text: str, language: str) -> object:
104
+ """Converts text into audio object.
105
+ Args:
106
+ text (str): generated answer of bot
107
+ Returns:
108
+ object: text to speech object
109
+ """
110
+ return gTTS(text=text, lang=language, slow=False)
111
+
112
+
113
+ def tts_to_bytesio(tts_object: object) -> bytes:
114
+ """Converts tts object to bytes.
115
+ Args:
116
+ tts_object (object): audio object obtained from gtts
117
+ Returns:
118
+ bytes: audio bytes
119
+ """
120
+ bytes_object = BytesIO()
121
+ tts_object.write_to_fp(bytes_object)
122
+ bytes_object.seek(0)
123
+ return bytes_object.getvalue()
124
+
125
+
126
+ def html_audio_autoplay(bytes: bytes) -> object:
127
+ """Creates html object for autoplaying audio at gradio app.
128
+ Args:
129
+ bytes (bytes): audio bytes
130
+ Returns:
131
+ object: html object that provides audio autoplaying
132
+ """
133
+ b64 = b64encode(bytes).decode()
134
+ html = f"""
135
+ <audio controls autoplay>
136
+ <source src="data:audio/wav;base64,{b64}" type="audio/wav">
137
+ </audio>
138
+ """
139
+ return html