Oysiyl committed on
Commit
dbf0dc1
1 Parent(s): e6e82cb

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +104 -0
handler.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Text, Any
2
+ import re
3
+ from transformers import SpeechT5ForTextToSpeech
4
+ from transformers import SpeechT5Processor
5
+ from transformers import SpeechT5HifiGan
6
+ import soundfile
7
+ import torch
8
+ import numpy as np
9
+
10
# Select the compute device. This handler is GPU-only, so importing the module
# fails fast when CUDA is not available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    raise ValueError("need to run on GPU")

# Mixed-precision dtype: use bfloat16 when the GPU's major compute capability
# is 8 or higher, otherwise float16. The comparison is `>=` rather than `==`
# so that architectures newer than major version 8 do not silently fall back
# to float16.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
16
+
17
+
18
class EndpointHandler():
    """Text-to-speech endpoint handler built on a fine-tuned SpeechT5 model.

    Incoming text is normalised, transliterated from Cyrillic to Latin
    characters, converted to a spectrogram by the SpeechT5 model and finally
    turned into a waveform by the HiFi-GAN vocoder.

    Relies on the module-level ``device`` and ``dtype`` globals.
    """

    # Characters that are dropped from the input text.
    _REMOVED_CHARS = (
        "\u2026"  # horizontal ellipsis
        "\u2013"  # en dash
        '"'
        "\u201c"  # left double quotation mark
        "%"
        "\u2018"  # left single quotation mark
        "\u201d"  # right double quotation mark
        "\ufffd"  # replacement character
        "\u00bb"  # right-pointing double angle quotation mark
        "\u00ab"  # left-pointing double angle quotation mark
        "\u201e"  # double low-9 quotation mark
        "`"
        "'"
        "\u0301"  # combining acute accent
    )

    # Single-pass clean-up table: U+055A and U+2019 become an ASCII
    # apostrophe, 'ы' becomes 'и', and everything in _REMOVED_CHARS is deleted.
    _CLEANUP_TABLE = str.maketrans("\u055a\u2019ы", "''и", _REMOVED_CHARS)

    # Lower-case Cyrillic -> Latin transliteration table.
    _TRANSLIT_TABLE = str.maketrans({
        'а': 'a',
        'б': 'b',
        'в': 'v',
        'г': 'h',
        'д': 'd',
        'е': 'e',
        'ж': 'zh',
        'з': 'z',
        'и': 'y',
        'й': 'j',
        'к': 'k',
        'л': 'l',
        'м': 'm',
        'н': 'n',
        'о': 'o',
        'п': 'p',
        'р': 'r',
        'с': 's',
        'т': 't',
        'у': 'u',
        'ф': 'f',
        'х': 'h',
        'ц': 'ts',
        'ч': 'ch',
        'ш': 'sh',
        'щ': 'sch',
        'ь': "'",
        'ю': 'ju',
        'я': 'ja',
        'є': 'je',
        'і': 'i',
        'ї': 'ji',
        'ґ': 'g',
    })

    def __init__(self, path=""):
        """Load the TTS model, processor, vocoder and speaker embedding.

        :param path: Directory that is searched first for ``embed.npy``.
            When the file is not found there, the working-directory-relative
            name ``embed.npy`` is used instead.
        """
        # Local import: `os` is not part of this file's module-level imports.
        import os

        # Load all required models
        self.model_id = "Oysiyl/speecht5_tts_common_voice_uk"
        self.model = SpeechT5ForTextToSpeech.from_pretrained(
            self.model_id, torch_dtype=dtype
        ).to(device)
        self.processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
        self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)

        # Prefer the embedding shipped next to the handler; fall back to the
        # original cwd-relative lookup so existing deployments keep working.
        embed_path = os.path.join(path or "", "embed.npy")
        if not os.path.isfile(embed_path):
            embed_path = "embed.npy"
        self.speaker_embeddings = torch.tensor(np.load(embed_path), dtype=dtype).to(device)

    @staticmethod
    def remove_special_characters_s(text: Text) -> Text:
        """Return ``text`` lower-cased with special characters normalised.

        Quotation marks, ellipsis, en dash, percent sign, backtick, ASCII
        apostrophe, the replacement character and the combining acute accent
        are removed; U+055A and U+2019 are turned into an ASCII apostrophe;
        'ы' is replaced by 'и'.
        """
        # Lower-case first so that an upper-case 'Ы' is also mapped to 'и'.
        return text.lower().translate(EndpointHandler._CLEANUP_TABLE)

    @staticmethod
    def cyrillic_to_latin(text: Text) -> Text:
        """Transliterate lower-case Cyrillic letters in ``text`` to Latin.

        Characters without an entry in the table are left unchanged.
        """
        return text.translate(EndpointHandler._TRANSLIT_TABLE)

    def __call__(self, data: Any) -> Any:
        """Synthesise speech for the text found under ``data["inputs"]``.

        :param data: A dictionary containing `inputs` (the text to speak).
            The `inputs` key is removed from the dictionary.
        :return: A numpy array with the generated waveform, or a dictionary
            with an `error` field when no usable text was provided.
        """
        text = data.pop("inputs", None)

        # Reject missing, non-string and blank input up front.
        if not isinstance(text, str) or not text.strip():
            return {"error": "Please provide a text."}

        # run inference pipeline
        text = self.remove_special_characters_s(text)
        text = self.cyrillic_to_latin(text)
        input_ids = self.processor(text=text, return_tensors="pt")['input_ids'].to(device)
        with torch.no_grad():
            spectrogram = self.model.generate_speech(input_ids, self.speaker_embeddings)
            speech = self.vocoder(spectrogram)

        # `.cpu()` is valid for tensors on any device, so no device branch is
        # needed. numpy cannot represent bfloat16, hence the defensive
        # up-cast to float32 before conversion.
        return speech.float().cpu().numpy()