commit all

Changed files:
- app.py +362 -29
- flagged/log.csv +2 -0
- hid512_decoder_att_epoch_20.pt +3 -0
- hid512_encoder_att_epoch_20.pt +3 -0
- requirements.txt +5 -1
- temp.ipynb +569 -0
- vocab_source.pkl +3 -0
- vocab_target.pkl +3 -0
app.py
CHANGED
@@ -1,49 +1,382 @@

Old version (removed lines are prefixed with -, unchanged lines with a space; several long lines appear truncated in the diff view):

 import gradio as gr
-from transformers import pipeline

 envit5_translater = pipeline("translation", model="VietAI/envit5-translation")

 def envit5_translation(text):
     res = envit5_translater(
         text,
         max_length=512,
         early_stopping=True,
-    )[0][
     return res

-def my_translation(text):
-    return "My Translation"
-
-def finetune_BERT(text):
-    return "BERT"

 def translation(text):
-    output1 =
     output2 = envit5_translation(text)
-    output3 = finetune_BERT(text)
-
-    return (output1, output2
-
-
-
-
-Multi-domain Translation Between English and Vietnamese
-Using VietAI Translation
-</center>
-</p>
-"""
-examples = [
-    ["Dear God, thank you for granting us the evergreen garden of this world", "en->vi"],
-    ["Thuốc này đã bị cấm sử dụng trong ngành thú y tại Ấn Độ.", "vi->en"]
-]

 demo = gr.Interface(
     fn=translation,
     title="Co Gai Mo Duong",
-    description=
     examples=examples,
-    inputs=
-
-

-demo.launch()

New version:

import gradio as gr
from transformers import pipeline
import re
import pickle
import torch
import torch.nn as nn
from torchtext.transforms import PadTransform
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from tqdm import tqdm
from underthesea import word_tokenize, text_normalize

# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = "cpu"

# Build Vocabulary
MAX_LENGTH = 15

class Vocabulary:
    """The Vocabulary class records words; it is used to convert
    text to numbers and vice versa.
    """

    def __init__(self, lang="vi"):
        self.lang = lang
        self.word2id = dict()
        self.word2id["<sos>"] = 0  # Start of Sentence Token
        self.word2id["<eos>"] = 1  # End of Sentence Token
        self.word2id["<unk>"] = 2  # Unknown Token
        self.word2id["<pad>"] = 3  # Pad Token
        self.sos_id = self.word2id["<sos>"]
        self.eos_id = self.word2id["<eos>"]
        self.unk_id = self.word2id["<unk>"]
        self.pad_id = self.word2id["<pad>"]
        self.id2word = {v: k for k, v in self.word2id.items()}
        self.pad_transform = PadTransform(max_length=MAX_LENGTH, pad_value=self.pad_id)

    def __getitem__(self, word):
        """Return the ID of the word if it exists, else the ID of the unknown token
        @param word (str)
        """
        return self.word2id.get(word, self.unk_id)

    def __contains__(self, word):
        """Return True if the word is in the Vocabulary, else False
        @param word (str)
        """
        return word in self.word2id

    def __len__(self):
        """Return the number of tokens (including the sos, eos, unk and pad tokens) in the Vocabulary"""
        return len(self.word2id)

    def lookup_tokens(self, word_indexes: list):
        """Return the list of words looked up by ID
        @param word_indexes (list(int))
        @return words (list(str))
        """
        return [self.id2word[word_index] for word_index in word_indexes]

    def add(self, word):
        """Add a word to the vocabulary
        @param word (str)
        @return index (int): index of the word just added
        """
        if word not in self:
            word_index = self.word2id[word] = len(self.word2id)
            self.id2word[word_index] = word
            return word_index
        else:
            return self[word]

    def preprocessing_sent(self, sent, lang="en"):
        """Preprocess a sentence (depends on the language, English or Vietnamese)
        @param sent (str)
        @param lang (str)
        """

        # Lowercase the sentence and strip leading/trailing spaces
        sent = sent.lower().strip()

        # Separate punctuation from the preceding word, then collapse repeated spaces
        sent = re.sub(r"(?<=\w)\.", " .", sent)
        sent = re.sub(r"(?<=\w),", " ,", sent)
        sent = re.sub(r"(?<=\w)\?", " ?", sent)
        sent = re.sub(r"(?<=\w)!", " !", sent)
        sent = re.sub(r" +", " ", sent)

        if (lang == "en") or (lang == "eng") or (lang == "english"):
            # Expand short forms
            sent = re.sub("what's", "what is", sent)
            sent = re.sub("who's", "who is", sent)
            sent = re.sub("which's", "which is", sent)

            sent = re.sub("i'm", "i am", sent)
            # Possessive case is left as-is
            sent = re.sub("it's", "it is", sent)
            sent = re.sub("'re ", " are ", sent)
            sent = re.sub("'ve ", " have ", sent)
            sent = re.sub("'ll ", " will ", sent)
            sent = re.sub("'d ", " would ", sent)

            sent = re.sub("aren't", "are not", sent)
            sent = re.sub("isn't", "is not", sent)
            sent = re.sub("don't", "do not", sent)
            sent = re.sub("doesn't", "does not", sent)
            sent = re.sub("wasn't", "was not", sent)
            sent = re.sub("weren't", "were not", sent)
            sent = re.sub("won't", "will not", sent)
            sent = re.sub("can't", "can not", sent)
            sent = re.sub("let's", "let us", sent)

        else:
            # underthesea.text_normalize normalizes Vietnamese text
            sent = text_normalize(sent)

        return sent.strip()

    def tokenize_corpus(self, corpus, disable=False):
        """Split the documents of the corpus into words
        @param corpus (list(str)): list of documents
        @param disable (bool): whether to silence the progress output
        @return tokenized_corpus (list(list(str))): list of words
        """
        if not disable:
            print("Tokenize the corpus...")
        tokenized_corpus = list()
        for document in tqdm(corpus, disable=disable):
            tokenized_document = ["<sos>"] + self.preprocessing_sent(document, self.lang).split(" ") + ["<eos>"]
            tokenized_corpus.append(tokenized_document)
        return tokenized_corpus

    def corpus_to_tensor(self, corpus, is_tokenized=False, disable=False):
        """Convert a corpus to a list of index tensors
        @param corpus (list(str) if is_tokenized==False else list(list(str)))
        @param is_tokenized (bool)
        @return indicies_corpus (list(tensor))
        """
        if is_tokenized:
            tokenized_corpus = corpus
        else:
            tokenized_corpus = self.tokenize_corpus(corpus, disable=disable)
        indicies_corpus = list()
        for document in tqdm(tokenized_corpus, disable=disable):
            indicies_document = torch.tensor(
                list(map(lambda word: self[word], document)), dtype=torch.int64
            )

            indicies_corpus.append(self.pad_transform(indicies_document))

        return indicies_corpus

    def tensor_to_corpus(self, tensor, disable=False):
        """Convert a list of index tensors to a list of tokenized documents
        @param indicies_corpus (list(tensor))
        @return corpus (list(list(str)))
        """
        corpus = list()
        for indicies in tqdm(tensor, disable=disable):
            document = list(map(lambda index: self.id2word[index.item()], indicies))
            corpus.append(document)

        return corpus
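
A minimal round-trip sketch of how this class is used, with a hypothetical toy vocabulary in place of the pickled ones loaded below:

# Hypothetical toy example; the real vocabularies come from vocab_source.pkl / vocab_target.pkl.
toy_vocab = Vocabulary(lang="en")
for w in "i am a student".split():
    toy_vocab.add(w)

tensors = toy_vocab.corpus_to_tensor(["I am a student"], disable=True)
print(tensors[0])  # <sos> ... <eos>, padded with <pad> up to MAX_LENGTH
print(toy_vocab.tensor_to_corpus(tensors, disable=True))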
168 |
+
with open("vocab_source.pkl", "rb") as file:
|
169 |
+
VOCAB_SOURCE = pickle.load(file)
|
170 |
+
with open("vocab_target.pkl", "rb") as file:
|
171 |
+
VOCAB_TARGET = pickle.load(file)
|
172 |
+
|
173 |
+
input_embedding = torch.zeros((len(VOCAB_SOURCE), 100))
|
174 |
+
output_embedding = torch.zeros((len(VOCAB_TARGET), 100))
|
175 |
+
|
176 |
+
|
177 |
+
def create_input_emb_layer():
|
178 |
+
num_embeddings, embedding_dim = input_embedding.size()
|
179 |
+
emb_layer = nn.Embedding(num_embeddings, embedding_dim)
|
180 |
+
emb_layer.weight.requires_grad = False
|
181 |
+
return emb_layer, embedding_dim
|
182 |
+
|
183 |
+
def create_output_emb_layer():
|
184 |
+
num_embeddings, embedding_dim = output_embedding.size()
|
185 |
+
emb_layer = nn.Embedding(num_embeddings, embedding_dim)
|
186 |
+
emb_layer.weight.requires_grad = False
|
187 |
+
return emb_layer, embedding_dim
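
Both factories create frozen, zero-initialized embedding layers; the real weights only arrive when the checkpoints are loaded further down, since load_state_dict overwrites the values without touching the requires_grad flag. A quick sanity check, assuming the encoder checkpoint below has already been loaded:

# Assumes ENCODER.load_state_dict(...) further down has run.
assert ENCODER.embedding.weight.requires_grad is False
print(ENCODER.embedding.weight.abs().sum())  # non-zero once checkpoint weights are in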

class EncoderRNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, dropout=0.1):
        """Encoder RNN
        @param input_dim (int): size of vocab_source
        @param hidden_dim (int)
        @param dropout (float): dropout ratio of the dropout layer
        """
        super(EncoderRNN, self).__init__()
        self.hidden_dim = hidden_dim
        # self.embedding = nn.Embedding(input_dim, hidden_dim)
        # Changed to the pretrained input embedding
        self.embedding, self.embedding_dim = create_input_emb_layer()
        self.gru = nn.GRU(self.embedding_dim, hidden_dim, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        embedded = self.dropout(self.embedding(src))
        output, hidden = self.gru(embedded)
        return output, hidden

class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        """Bahdanau (additive) attention
        @param hidden_size (int)
        """
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        scores = scores.squeeze(2).unsqueeze(1)

        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)

        return context, weights
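
The score here is the standard additive form score(q, k) = Va^T tanh(Wa q + Ua k). A small shape walk-through, with hypothetical batch and length sizes chosen purely for illustration:

# Hypothetical sizes: batch of 2, source length 15, hidden size 512.
attn = BahdanauAttention(512)
query = torch.randn(2, 1, 512)   # decoder hidden state, reshaped batch-first
keys = torch.randn(2, 15, 512)   # encoder outputs
context, weights = attn(query, keys)
print(context.shape)             # torch.Size([2, 1, 512])
print(weights.shape)             # torch.Size([2, 1, 15])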

class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1):
        """Decoder RNN using attention
        @param hidden_size (int)
        @param output_size (int): size of vocab_target
        @param dropout_p (float): dropout ratio of the dropout layer
        """
        super(AttnDecoderRNN, self).__init__()
        self.embedding, self.embedding_dim = create_output_emb_layer()
        self.fc = nn.Linear(self.embedding_dim, hidden_size)
        self.attention = BahdanauAttention(hidden_size)
        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(0)
        decoder_hidden = encoder_hidden
        decoder_outputs = []
        attentions = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if target_tensor is not None:
                # Teacher forcing: feed the target as the next input
                decoder_input = target_tensor[:, i].unsqueeze(1)
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        attentions = torch.cat(attentions, dim=1)

        return decoder_outputs, decoder_hidden, attentions

    def forward_step(self, input, hidden, encoder_outputs):
        embedded = self.dropout(self.fc(self.embedding(input)))

        query = hidden.permute(1, 0, 2)
        context, attn_weights = self.attention(query, encoder_outputs)
        input_gru = torch.cat((embedded, context), dim=2)

        output, hidden = self.gru(input_gru, hidden)
        output = self.out(output)

        return output, hidden, attn_weights

# Load the VietAI translation pipeline
envit5_translater = pipeline("translation", model="VietAI/envit5-translation")

INPUT_DIM = len(VOCAB_SOURCE)
OUTPUT_DIM = len(VOCAB_TARGET)
HID_DIM = 512

# Load our translation model
ENCODER = EncoderRNN(INPUT_DIM, HID_DIM)
ENCODER.load_state_dict(torch.load("hid512_encoder_att_epoch_20.pt"))
DECODER = AttnDecoderRNN(HID_DIM, OUTPUT_DIM)
DECODER.load_state_dict(torch.load("hid512_decoder_att_epoch_20.pt"))
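
Since device is pinned to "cpu" above, checkpoints that were saved from CUDA tensors would fail to load as written; passing map_location is the usual guard. A hedged variant, in case the .pt files were saved on a GPU:

# Safe on CPU-only hosts regardless of where the checkpoint was saved.
ENCODER.load_state_dict(torch.load("hid512_encoder_att_epoch_20.pt", map_location=device))
DECODER.load_state_dict(torch.load("hid512_decoder_att_epoch_20.pt", map_location=device))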

def evaluate(encoder, decoder, sentence, vocab_source, vocab_target, disable=False):
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        input_tensor = (
            vocab_source.corpus_to_tensor([sentence], disable=disable)[0]
            .view(1, -1)
            .to(device)
        )

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, decoder_hidden, decoder_attn = decoder(
            encoder_outputs, encoder_hidden
        )

        _, topi = decoder_outputs.topk(1)
        decoded_ids = topi.squeeze()

        decoded_words = []
        for idx in decoded_ids:
            if idx.item() == vocab_target.eos_id:
                decoded_words.append("<eos>")
                break
            decoded_words.append(vocab_target.id2word[idx.item()])
    return decoded_words, decoder_attn


def my_translate_model(sentence):
    output_words, _ = evaluate(
        ENCODER, DECODER, sentence, VOCAB_SOURCE, VOCAB_TARGET, disable=True
    )

    return " ".join(output_words[1:-1]).capitalize()
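
For reference, the scratch notebook committed below exercises exactly this path: with the epoch-16 checkpoints it recorded my_translate_model("I hope you will be better") returning 'Tôi hy vọng các bạn sẽ có thể làm được giải pháp.', so outputs from the epoch-20 weights loaded here may differ.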

def envit5_translation(text):
    res = envit5_translater(
        text,
        max_length=512,
        early_stopping=True,
    )[0]["translation_text"][3:]  # drop the leading language tag emitted by envit5
    return res


def translation(text):
    output1 = my_translate_model(text)
    output2 = envit5_translation(text)
    # output3 = finetune_BERT(text)

    return (output1, output2)


examples = [["Input: Hello guys"],
            ["Output: Xin chào các bạn"]]

demo = gr.Interface(
    theme=gr.themes.Base(),
    fn=translation,
    title="Co Gai Mo Duong",
    description="""
    ## Machine Translation: English to Vietnamese
    """,
    examples=examples,
    inputs=[
        gr.Textbox(
            lines=5, placeholder="Enter text", label="Input"
        )
    ],
    outputs=[
        gr.Textbox(
            "text", label="Our Machine Translation"
        ),
        gr.Textbox(
            "text", label="VietAI Machine Translation"
        )
    ]
)

demo.launch(share=True)
flagged/log.csv
ADDED
@@ -0,0 +1,2 @@
+Input,Our Machine Translation,VietAI Machine Translation,flag,username,timestamp
+Today is a beautiful day.,Hôm nay là một ngày đẹp đẹp, Hôm nay là một ngày đẹp trời.,,,2024-01-11 02:01:36.293799
hid512_decoder_att_epoch_20.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b13f49e00d60a51226db3a66e343ef3b73eccf06e0efe771cac417e1994a706
+size 40323250
hid512_encoder_att_epoch_20.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ec38b650930515f30086a04a16285c88430ceed352cfbd52cc27e34b4283221a
+size 16096464
requirements.txt
CHANGED
@@ -2,4 +2,8 @@ transformers
 sentencepiece
 tokenizers
 torch
-gradio
+gradio
+re
+pickle
+torchtext
+underthesea
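
Note that re and pickle are Python standard-library modules, not PyPI packages, so pip cannot install them and the Space's build may fail on these two lines. A requirements list matching app.py's actual third-party imports would presumably be:

transformers
sentencepiece
tokenizers
torch
gradio
torchtext
underthesea
tqdm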
temp.ipynb
ADDED
@@ -0,0 +1,569 @@
The committed notebook (Python 3 kernel, Python 3.11.5, nbformat 4) is a scratchpad that mirrors app.py with epoch-16 checkpoints and MAX_LENGTH = 30. Its cells, with their recorded outputs:

Cell [1] (imports, identical to the import block of app.py):

import gradio as gr
from transformers import pipeline
import re
import pickle
import torch
import torch.nn as nn
from torchtext.transforms import PadTransform
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from tqdm import tqdm
from underthesea import word_tokenize, text_normalize

Output:
WARNING:tensorflow:From c:\Users\THU\AppData\Local\Programs\Python\Python311\Lib\site-packages\keras\src\losses.py:2976: The name tf.losses.sparse_softmax_cross_entropy is deprecated. Please use tf.compat.v1.losses.sparse_softmax_cross_entropy instead.

Cell [7] (a stub of the Gradio interface wired to dummy outputs):

import gradio as gr

def translation(text):
    output1 = 1
    output2 = 2
    # output3 = finetune_BERT(text)

    return (output1, output2)


examples = [["Input: Hello guys"],
            ["Output: Xin chào các bạn"]]

demo = gr.Interface(
    theme=gr.themes.Base(),
    fn=translation,
    title="Co Gai Mo Duong",
    description="""
    ## Machine Translation: English to Vietnamese
    """,
    examples=examples,
    inputs=[
        gr.Textbox(
            lines=5, placeholder="Enter text", label="Input"
        )
    ],
    outputs=[
        gr.Textbox(
            "text", label="Our Machine Translation"
        ),
        gr.Textbox(
            "text", label="VietAI Machine Translation"
        )
    ]
)

demo.launch(shared=True)  # "shared" is not a launch() keyword; the correct name is share=True

Output:
Running on local URL: http://127.0.0.1:7864
To create a public link, set `share=True` in `launch()`.
(followed by an inline iframe embedding the running app)

Cell [2]:

# Build Vocabulary
MAX_LENGTH = 30
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'

Cell [3] defines the same Vocabulary class as app.py, with two differences: tokenize_corpus calls preprocessing_sent(document) without passing self.lang, and preprocessing_sent is an earlier draft that lowercases last and strips HTML entities left over in the corpus:

    def preprocessing_sent(self, sent, lang="en"):
        """Preprocess a sentence (depends on the language, English or Vietnamese)"""

        if (lang == "en") or (lang == "eng") or (lang == "english"):
            # Collapse repeated spaces
            sent = re.sub(" +", " ", sent)

            # Expand short forms
            sent = re.sub("'m ", "am ", sent)
            # Possessive case is left as-is
            sent = re.sub("'s ", "is ", sent)
            sent = re.sub("'re ", "are ", sent)
            sent = re.sub("'ve ", "have ", sent)
            sent = re.sub("'ll ", "will ", sent)
            sent = re.sub("'d ", "would ", sent)

            sent = re.sub("aren 't", "are not", sent)
            sent = re.sub("isn 't", "is not", sent)
            sent = re.sub("don 't", "do not", sent)
            sent = re.sub("doesn 't", "does not", sent)
            sent = re.sub("wasn 't", "was not", sent)
            sent = re.sub("weren 't", "were not", sent)
            sent = re.sub("won 't", "will not", sent)
            sent = re.sub("can 't", "can not", sent)
            sent = re.sub("let 's", "let us", sent)

        else:
            # underthesea.text_normalize normalizes Vietnamese text
            sent = text_normalize(sent)

        # Strip HTML entities left over in the corpus
        sent = re.sub("&apos;", "'", sent)
        sent = re.sub("&quot;", '"', sent)
        sent = re.sub("&#91;", "[", sent)
        sent = re.sub("&#93;", "]", sent)

        # Lowercase the sentence and strip leading/trailing spaces
        return sent.lower().strip()

Cell [4] (embedding factories with hard-coded sizes, then the model classes):

def create_input_emb_layer():
    num_embeddings, embedding_dim = 32998, 100
    emb_layer = nn.Embedding(num_embeddings, embedding_dim)
    emb_layer.weight.requires_grad = False

    return emb_layer, embedding_dim

def create_output_emb_layer():
    num_embeddings, embedding_dim = 15405, 100
    emb_layer = nn.Embedding(num_embeddings, embedding_dim)
    emb_layer.weight.requires_grad = False

    return emb_layer, embedding_dim

EncoderRNN, BahdanauAttention and AttnDecoderRNN follow, identical to the app.py definitions above except that the encoder uses dropout=0.2; the decoder keeps a commented-out nn.Embedding(output_size, hidden_size) line next to a note that it was changed to the pretrained output embedding.

An empty cell follows, then:

Cell [41] (loads the vocabularies and the epoch-16 checkpoints):

with open("vocab_source.pkl", "rb") as file:
    VOCAB_SOURCE = pickle.load(file)
with open("vocab_target.pkl", "rb") as file:
    VOCAB_TARGET = pickle.load(file)

INPUT_DIM = len(VOCAB_SOURCE)
OUTPUT_DIM = len(VOCAB_TARGET)
HID_DIM = 512

# Load our translation model
ENCODER = EncoderRNN(INPUT_DIM, HID_DIM)
ENCODER.load_state_dict(torch.load('encoder_att_epoch_16.pt'))
DECODER = AttnDecoderRNN(HID_DIM, OUTPUT_DIM)
DECODER.load_state_dict(torch.load('decoder_att_epoch_16.pt'))

Output: <All keys matched successfully>

Cell [42] (the same evaluate as app.py, plus a my_translate_model that appends a trailing period):

def my_translate_model(sentence):
    output_words, _ = evaluate(ENCODER, DECODER, sentence, VOCAB_SOURCE, VOCAB_TARGET, disable=True)

    return ' '.join(output_words[1:-1]).capitalize() + '.'

Cell [61]:

my_translate_model("I hope you will be better")

Output: 'Tôi hy vọng các bạn sẽ có thể làm được giải pháp.'

Cell [60] re-instantiates both models and reloads the epoch-16 checkpoints; output: <All keys matched successfully>

Cell [48]:

torch.load('decoder_att_epoch_16.pt').keys()

Output:
odict_keys(['embedding.weight', 'fc.weight', 'fc.bias', 'attention.Wa.weight', 'attention.Wa.bias', 'attention.Ua.weight', 'attention.Ua.bias', 'attention.Va.weight', 'attention.Va.bias', 'gru.weight_ih_l0', 'gru.weight_hh_l0', 'gru.bias_ih_l0', 'gru.bias_hh_l0', 'out.weight', 'out.bias'])

Cells [52] and [57] rebuild DECODER and reload 'decoder_att_epoch_16.pt' once more; both output: <All keys matched successfully>
vocab_source.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf38f3daacf3feb3b80cba2069210d5ac3b770c232233178f42434b709bba360
+size 659103
vocab_target.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac7bd376478b2b3bbcfbeeccd5ced630340b95d3da5eab8d7c1c9e01d74b50d2
+size 228271