|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
import random |
|
import enum |
|
import traceback |
|
|
|
import os |
|
import sys |
|
import json |
|
|
|
F_DIR = os.path.dirname(os.path.realpath(__file__)) |
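# Directory containing this file; used to resolve model-weight and glyph-config paths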
|
|
|
|
|
class XlitError(enum.Enum): |
|
lang_err = "Unsupported langauge ID requested ;( Please check available languages." |
|
    string_err = "String passed is incompatible ;("
|
internal_err = "Internal crash ;(" |
|
unknown_err = "Unknown Failure" |
|
loading_err = "Loading failed ;( Check if metadata/paths are correctly configured." |
|
|
|
|
|
class Encoder(nn.Module): |
|
""" |
|
    Simple RNN-based encoder network (GRU or LSTM, optionally bidirectional)
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_dim, |
|
embed_dim, |
|
hidden_dim, |
|
rnn_type="gru", |
|
layers=1, |
|
bidirectional=False, |
|
dropout=0, |
|
device="cpu", |
|
): |
|
super(Encoder, self).__init__() |
|
|
|
self.input_dim = input_dim |
|
self.enc_embed_dim = embed_dim |
|
self.enc_hidden_dim = hidden_dim |
|
self.enc_rnn_type = rnn_type |
|
self.enc_layers = layers |
|
self.enc_directions = 2 if bidirectional else 1 |
|
self.device = device |
|
|
|
self.embedding = nn.Embedding(self.input_dim, self.enc_embed_dim) |
|
|
|
if self.enc_rnn_type == "gru": |
|
self.enc_rnn = nn.GRU( |
|
input_size=self.enc_embed_dim, |
|
hidden_size=self.enc_hidden_dim, |
|
num_layers=self.enc_layers, |
|
bidirectional=bidirectional, |
|
) |
|
elif self.enc_rnn_type == "lstm": |
|
self.enc_rnn = nn.LSTM( |
|
input_size=self.enc_embed_dim, |
|
hidden_size=self.enc_hidden_dim, |
|
num_layers=self.enc_layers, |
|
bidirectional=bidirectional, |
|
) |
|
else: |
|
raise Exception("unknown RNN type mentioned") |
|
|
|
def forward(self, x, x_sz, hidden=None): |
|
""" |
|
        x_sz: (batch_size, 1) - Unpadded sequence lengths, used for pack_padded_sequence

        Returns:
            output: (batch_size, max_length, hidden_dim * num_directions)
            hidden: (n_layers * num_directions, batch_size, hidden_dim) | if LSTM, a tuple (h_n, c_n)
|
|
|
""" |
|
batch_sz = x.shape[0] |
|
|
|
x = self.embedding(x) |
|
|
|
|
|
|
|
x = x.permute(1, 0, 2) |
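        # pack_padded_sequence expects time-major input (max_length, batch, embed_dim)
        # because the RNN was built with batch_first=False; x_sz holds the unpadded
        # lengths (recent PyTorch versions require the lengths on CPU).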
|
x = nn.utils.rnn.pack_padded_sequence(x, x_sz, enforce_sorted=False) |
|
|
|
|
|
|
|
output, hidden = self.enc_rnn(x) |
|
|
|
|
|
|
|
output, _ = nn.utils.rnn.pad_packed_sequence(output) |
|
|
|
|
|
output = output.permute(1, 0, 2) |
|
|
|
return output, hidden |
|
|
|
|
|
class Decoder(nn.Module): |
|
""" |
|
    RNN-based decoder stage with optional additive attention over encoder outputs
|
""" |
|
|
|
def __init__( |
|
self, |
|
output_dim, |
|
embed_dim, |
|
hidden_dim, |
|
rnn_type="gru", |
|
layers=1, |
|
use_attention=True, |
|
enc_outstate_dim=None, |
|
dropout=0, |
|
device="cpu", |
|
): |
|
super(Decoder, self).__init__() |
|
|
|
self.output_dim = output_dim |
|
self.dec_hidden_dim = hidden_dim |
|
self.dec_embed_dim = embed_dim |
|
self.dec_rnn_type = rnn_type |
|
self.dec_layers = layers |
|
self.use_attention = use_attention |
|
self.device = device |
|
if self.use_attention: |
|
self.enc_outstate_dim = enc_outstate_dim if enc_outstate_dim else hidden_dim |
|
else: |
|
self.enc_outstate_dim = 0 |
|
|
|
self.embedding = nn.Embedding(self.output_dim, self.dec_embed_dim) |
|
|
|
if self.dec_rnn_type == "gru": |
|
self.dec_rnn = nn.GRU( |
|
input_size=self.dec_embed_dim |
|
+ self.enc_outstate_dim, |
|
hidden_size=self.dec_hidden_dim, |
|
num_layers=self.dec_layers, |
|
batch_first=True, |
|
) |
|
elif self.dec_rnn_type == "lstm": |
|
self.dec_rnn = nn.LSTM( |
|
input_size=self.dec_embed_dim |
|
+ self.enc_outstate_dim, |
|
hidden_size=self.dec_hidden_dim, |
|
num_layers=self.dec_layers, |
|
batch_first=True, |
|
) |
|
else: |
|
raise Exception("unknown RNN type mentioned") |
|
|
|
self.fc = nn.Sequential( |
|
nn.Linear(self.dec_hidden_dim, self.dec_embed_dim), |
|
nn.LeakyReLU(), |
|
|
|
nn.Linear(self.dec_embed_dim, self.output_dim), |
|
) |
|
|
|
|
|
if self.use_attention: |
|
self.W1 = nn.Linear(self.enc_outstate_dim, self.dec_hidden_dim) |
|
self.W2 = nn.Linear(self.dec_hidden_dim, self.dec_hidden_dim) |
|
self.V = nn.Linear(self.dec_hidden_dim, 1) |
|
|
|
def attention(self, x, hidden, enc_output): |
|
""" |
|
        x: (batch_size, 1, dec_embed_dim) - embedded previous output token
        enc_output: (batch_size, max_length, enc_hidden_dim * num_directions)
        hidden: (n_layers, batch_size, hidden_dim) | if LSTM, a tuple (h_n, c_n)
|
""" |
|
|
|
|
|
|
|
|
|
|
|
hidden_with_time_axis = torch.sum(hidden, axis=0) |
|
|
|
hidden_with_time_axis = hidden_with_time_axis.unsqueeze(1) |
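        # Additive (Bahdanau-style) attention:
        #   score   = V(tanh(W1(enc_output) + W2(dec_hidden)))
        #   weights = softmax(score) over the time axis
        #   context = sum over time of weights * enc_output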
|
|
|
|
|
score = torch.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis)) |
|
|
|
|
|
|
|
attention_weights = torch.softmax(self.V(score), dim=1) |
|
|
|
|
|
context_vector = attention_weights * enc_output |
|
context_vector = torch.sum(context_vector, dim=1) |
|
|
|
context_vector = context_vector.unsqueeze(1) |
|
|
|
|
|
attend_out = torch.cat((context_vector, x), -1) |
|
|
|
return attend_out, attention_weights |
|
|
|
def forward(self, x, hidden, enc_output): |
|
""" |
|
        x: (batch_size, 1) - previous output token index
        enc_output: (batch_size, max_length, enc_outstate_dim)
        hidden: (n_layers, batch_size, hidden_dim) | if LSTM, a tuple (h_n, c_n)
|
""" |
|
if (hidden is None) and (self.use_attention is False): |
|
raise Exception("No use of a decoder with No attention and No Hidden") |
|
|
|
batch_sz = x.shape[0] |
|
|
|
if hidden is None: |
|
|
|
hid_for_att = torch.zeros( |
|
(self.dec_layers, batch_sz, self.dec_hidden_dim) |
|
).to(self.device) |
|
elif self.dec_rnn_type == "lstm": |
|
hid_for_att = hidden[0] |
|
else: |
|
hid_for_att = hidden |
|
|
|
|
|
x = self.embedding(x) |
|
|
|
if self.use_attention: |
|
|
|
|
|
x, aw = self.attention(x, hid_for_att, enc_output) |
|
else: |
|
x, aw = x, 0 |
|
|
|
|
|
|
|
|
|
output, hidden = ( |
|
self.dec_rnn(x, hidden) if hidden is not None else self.dec_rnn(x) |
|
) |
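        # output: (batch_size, 1, dec_hidden_dim) -> flatten the single time step so
        # the FC head maps (batch_size, dec_hidden_dim) to (batch_size, output_dim)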
|
|
|
|
|
output = output.view(-1, output.size(2)) |
|
|
|
|
|
output = self.fc(output) |
|
|
|
return output, hidden, aw |
|
|
|
|
|
class Seq2Seq(nn.Module): |
|
""" |
|
Used to construct seq2seq architecture with encoder decoder objects |
|
""" |
|
|
|
def __init__( |
|
self, encoder, decoder, pass_enc2dec_hid=False, dropout=0, device="cpu" |
|
): |
|
super(Seq2Seq, self).__init__() |
|
|
|
self.encoder = encoder |
|
self.decoder = decoder |
|
self.device = device |
|
self.pass_enc2dec_hid = pass_enc2dec_hid |
|
|
|
if self.pass_enc2dec_hid: |
|
assert ( |
|
decoder.dec_hidden_dim == encoder.enc_hidden_dim |
|
), "Hidden Dimension of encoder and decoder must be same, or unset `pass_enc2dec_hid`" |
|
if decoder.use_attention: |
|
assert ( |
|
decoder.enc_outstate_dim |
|
== encoder.enc_directions * encoder.enc_hidden_dim |
|
), "Set `enc_out_dim` correctly in decoder" |
|
assert ( |
|
self.pass_enc2dec_hid or decoder.use_attention |
|
), "No use of a decoder with No attention and No Hidden from Encoder" |
|
|
|
def forward(self, src, tgt, src_sz, teacher_forcing_ratio=0): |
|
""" |
|
        src: (batch_size, padded_sequence_len)
        tgt: (batch_size, padded_sequence_len)
        src_sz: (batch_size, 1) - Unpadded sequence lengths
|
""" |
|
batch_size = tgt.shape[0] |
|
|
|
|
|
|
|
enc_output, enc_hidden = self.encoder(src, src_sz) |
|
|
|
if self.pass_enc2dec_hid: |
|
|
|
dec_hidden = enc_hidden |
|
else: |
|
|
|
dec_hidden = None |
|
|
|
|
|
pred_vecs = torch.zeros(batch_size, self.decoder.output_dim, tgt.size(1)).to( |
|
self.device |
|
) |
|
|
|
|
|
dec_input = tgt[:, 0].unsqueeze(1) |
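        # index 1 is the '$' start token in GlyphStrawboss, so mark it as the
        # "prediction" at t = 0; the remaining steps are decoded one at a time,
        # feeding back either the ground truth (teacher forcing) or the argmax.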
|
pred_vecs[:, 1, 0] = 1 |
|
for t in range(1, tgt.size(1)): |
|
|
|
|
|
|
|
dec_output, dec_hidden, _ = self.decoder( |
|
dec_input, |
|
dec_hidden, |
|
enc_output, |
|
) |
|
pred_vecs[:, :, t] = dec_output |
|
|
|
|
|
prediction = torch.argmax(dec_output, dim=1) |
|
|
|
|
|
if random.random() < teacher_forcing_ratio: |
|
dec_input = tgt[:, t].unsqueeze(1) |
|
else: |
|
dec_input = prediction.unsqueeze(1) |
|
|
|
return pred_vecs |
|
|
|
def inference(self, src, max_tgt_sz=50, debug=0): |
|
""" |
|
        Greedy decoding for a single input (no batch inferencing)
        src: (sequence_len)
        debug: if truthy, also return the attention weights
|
""" |
|
batch_size = 1 |
|
start_tok = src[0] |
|
end_tok = src[-1] |
|
src_sz = torch.tensor([len(src)]) |
|
src_ = src.unsqueeze(0) |
|
|
|
|
|
|
|
enc_output, enc_hidden = self.encoder(src_, src_sz) |
|
|
|
if self.pass_enc2dec_hid: |
|
|
|
dec_hidden = enc_hidden |
|
else: |
|
|
|
dec_hidden = None |
|
|
|
|
|
pred_arr = torch.zeros(max_tgt_sz, 1).to(self.device) |
|
if debug: |
|
attend_weight_arr = torch.zeros(max_tgt_sz, len(src)).to(self.device) |
|
|
|
|
|
dec_input = start_tok.view(1, 1) |
|
pred_arr[0] = start_tok.view(1, 1) |
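        # Greedy decoding: feed the argmax prediction back in until the end
        # token ('#') is produced or max_tgt_sz steps are reached.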
|
for t in range(max_tgt_sz): |
|
|
|
|
|
|
|
dec_output, dec_hidden, aw = self.decoder( |
|
dec_input, |
|
dec_hidden, |
|
enc_output, |
|
) |
|
|
|
prediction = torch.argmax(dec_output, dim=1) |
|
dec_input = prediction.unsqueeze(1) |
|
pred_arr[t] = prediction |
|
if debug: |
|
attend_weight_arr[t] = aw.squeeze(-1) |
|
|
|
if torch.eq(prediction, end_tok): |
|
break |
|
|
|
if debug: |
|
return pred_arr.squeeze(), attend_weight_arr |
|
|
|
return pred_arr.squeeze().to(dtype=torch.long) |
|
|
|
def active_beam_inference(self, src, beam_width=3, max_tgt_sz=50): |
|
"""Active beam Search based decoding |
|
src: (sequence_len) |
|
""" |
|
|
|
def _avg_score(p_tup): |
|
"""Used for Sorting |
|
TODO: Dividing by length of sequence power alpha as hyperparam |
|
""" |
|
return p_tup[0] |
|
|
|
batch_size = 1 |
|
start_tok = src[0] |
|
end_tok = src[-1] |
|
src_sz = torch.tensor([len(src)]) |
|
src_ = src.unsqueeze(0) |
|
|
|
|
|
|
|
enc_output, enc_hidden = self.encoder(src_, src_sz) |
|
|
|
if self.pass_enc2dec_hid: |
|
|
|
init_dec_hidden = enc_hidden |
|
else: |
|
|
|
init_dec_hidden = None |
|
|
|
|
|
|
|
|
|
top_pred_list = [(0, start_tok.unsqueeze(0), init_dec_hidden)] |
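        # Each beam entry is (cumulative log-prob, decoded token tensor, decoder
        # hidden state); at every step the live beams are each expanded by
        # beam_width candidates and only the best beam_width overall are kept.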
|
|
|
for t in range(max_tgt_sz): |
|
cur_pred_list = [] |
|
|
|
for p_tup in top_pred_list: |
|
if p_tup[1][-1] == end_tok: |
|
cur_pred_list.append(p_tup) |
|
continue |
|
|
|
|
|
|
|
dec_output, dec_hidden, _ = self.decoder( |
|
x=p_tup[1][-1].view(1, 1), |
|
hidden=p_tup[2], |
|
enc_output=enc_output, |
|
) |
|
|
|
|
|
|
|
dec_output = nn.functional.log_softmax(dec_output, dim=1) |
|
|
|
pred_topk = torch.topk(dec_output, k=beam_width, dim=1) |
|
|
|
for i in range(beam_width): |
|
sig_logsmx_ = p_tup[0] + pred_topk.values[0][i] |
|
|
|
seq_tensor_ = torch.cat((p_tup[1], pred_topk.indices[0][i].view(1))) |
|
|
|
cur_pred_list.append((sig_logsmx_, seq_tensor_, dec_hidden)) |
|
|
|
cur_pred_list.sort(key=_avg_score, reverse=True) |
|
top_pred_list = cur_pred_list[:beam_width] |
|
|
|
|
|
end_flags_ = [1 if t[1][-1] == end_tok else 0 for t in top_pred_list] |
|
if beam_width == sum(end_flags_): |
|
break |
|
|
|
pred_tnsr_list = [t[1] for t in top_pred_list] |
|
|
|
return pred_tnsr_list |
|
|
|
def passive_beam_inference(self, src, beam_width=7, max_tgt_sz=50): |
|
""" |
|
        Passive beam search: decode greedily, collect the per-step top-k candidates, and assemble the beams afterwards
|
src: (sequence_len) |
|
""" |
|
|
|
def _avg_score(p_tup): |
|
"""Used for Sorting |
|
TODO: Dividing by length of sequence power alpha as hyperparam |
|
""" |
|
return p_tup[0] |
|
|
|
def _beam_search_topk(topk_obj, start_tok, beam_width): |
|
"""search for sequence with maxim prob |
|
topk_obj[x]: .values & .indices shape:(1, beam_width) |
|
""" |
|
|
|
top_pred_list = [ |
|
(0, start_tok.unsqueeze(0)), |
|
] |
|
|
|
for obj in topk_obj: |
|
new_lst_ = list() |
|
for itm in top_pred_list: |
|
for i in range(beam_width): |
|
sig_logsmx_ = itm[0] + obj.values[0][i] |
|
seq_tensor_ = torch.cat((itm[1], obj.indices[0][i].view(1))) |
|
new_lst_.append((sig_logsmx_, seq_tensor_)) |
|
|
|
new_lst_.sort(key=_avg_score, reverse=True) |
|
top_pred_list = new_lst_[:beam_width] |
|
return top_pred_list |
|
|
|
batch_size = 1 |
|
start_tok = src[0] |
|
end_tok = src[-1] |
|
src_sz = torch.tensor([len(src)]) |
|
src_ = src.unsqueeze(0) |
|
|
|
enc_output, enc_hidden = self.encoder(src_, src_sz) |
|
|
|
if self.pass_enc2dec_hid: |
|
|
|
dec_hidden = enc_hidden |
|
else: |
|
|
|
dec_hidden = None |
|
|
|
|
|
dec_input = start_tok.view(1, 1) |
|
|
|
topk_obj = [] |
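        # Decode greedily while recording each step's top-k log-probabilities;
        # the beams are only assembled afterwards by _beam_search_topk
        # (hence "passive" beam search).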
|
for t in range(max_tgt_sz): |
|
dec_output, dec_hidden, aw = self.decoder( |
|
dec_input, |
|
dec_hidden, |
|
enc_output, |
|
) |
|
|
|
|
|
|
|
dec_output = nn.functional.log_softmax(dec_output, dim=1) |
|
|
|
pred_topk = torch.topk(dec_output, k=beam_width, dim=1) |
|
|
|
topk_obj.append(pred_topk) |
|
|
|
|
|
dec_input = pred_topk.indices[0][0].view(1, 1) |
|
if torch.eq(dec_input, end_tok): |
|
break |
|
|
|
top_pred_list = _beam_search_topk(topk_obj, start_tok, beam_width) |
|
pred_tnsr_list = [t[1] for t in top_pred_list] |
|
|
|
return pred_tnsr_list |
|
|
|
|
|
class GlyphStrawboss: |
|
def __init__(self, glyphs="en"): |
|
"""list of letters in a language in unicode |
|
lang: List with unicodes |
|
""" |
|
if glyphs == "en": |
|
|
|
self.glyphs = [chr(alpha) for alpha in range(97, 123)] + ["é", "è", "á"] |
|
else: |
|
self.dossier = json.load(open(glyphs, encoding="utf-8")) |
|
self.numsym_map = self.dossier["numsym_map"] |
|
self.glyphs = self.dossier["glyphs"] |
|
|
|
self.indoarab_num = [chr(alpha) for alpha in range(48, 58)] |
|
|
|
self.char2idx = {} |
|
self.idx2char = {} |
|
self._create_index() |
|
|
|
def _create_index(self): |
|
|
|
self.char2idx["_"] = 0 |
|
self.char2idx["$"] = 1 |
|
self.char2idx["#"] = 2 |
|
self.char2idx["*"] = 3 |
|
self.char2idx["'"] = 4 |
|
self.char2idx["%"] = 5 |
|
self.char2idx["!"] = 6 |
|
self.char2idx["?"] = 7 |
|
self.char2idx[":"] = 8 |
|
self.char2idx[" "] = 9 |
|
self.char2idx["-"] = 10 |
|
self.char2idx[","] = 11 |
|
self.char2idx["."] = 12 |
|
self.char2idx["("] = 13 |
|
self.char2idx[")"] = 14 |
|
self.char2idx["/"] = 15 |
|
self.char2idx["^"] = 16 |
|
|
|
for idx, char in enumerate(self.indoarab_num): |
|
self.char2idx[char] = idx + 17 |
|
|
|
for idx, char in enumerate(self.glyphs): |
|
self.char2idx[char] = idx + 27 |
|
|
|
|
|
for char, idx in self.char2idx.items(): |
|
self.idx2char[idx] = char |
|
|
|
def size(self): |
|
return len(self.char2idx) |
|
|
|
def word2xlitvec(self, word): |
|
"""Converts given string of gyphs(word) to vector(numpy) |
|
Also adds tokens for start and end |
|
""" |
|
try: |
|
vec = [self.char2idx["$"]] |
|
for i in list(word): |
|
vec.append(self.char2idx[i]) |
|
vec.append(self.char2idx["#"]) |
|
|
|
vec = np.asarray(vec, dtype=np.int64) |
|
return vec |
|
|
|
except Exception as error: |
|
print("Error In word:", word, "Error Char not in Token:", error) |
|
sys.exit() |
|
|
|
def xlitvec2word(self, vector): |
|
"""Converts vector(numpy) to string of glyphs(word)""" |
|
char_list = [] |
|
for i in vector: |
|
char_list.append(self.idx2char[i]) |
|
|
|
word = "".join(char_list).replace("$", "").replace("#", "") |
|
word = word.replace("_", "").replace("*", "") |
|
return word |
|
|
|
|
|
class XlitPiston: |
|
""" |
|
For handling prediction & post-processing of transliteration for a single language |
|
Class dependency: Seq2Seq, GlyphStrawboss |
|
Global Variables: F_DIR |
|
""" |
|
|
|
def __init__( |
|
self, weight_path, tglyph_cfg_file, iglyph_cfg_file="en", device="cpu" |
|
): |
|
|
|
self.device = device |
|
self.in_glyph_obj = GlyphStrawboss(iglyph_cfg_file) |
|
self.tgt_glyph_obj = GlyphStrawboss(glyphs=tglyph_cfg_file) |
|
|
|
self._numsym_set = set( |
|
json.load(open(tglyph_cfg_file, encoding="utf-8"))["numsym_map"].keys() |
|
) |
|
self._inchar_set = set("abcdefghijklmnopqrstuvwxyzéèá") |
|
self._natscr_set = set().union( |
|
self.tgt_glyph_obj.glyphs, sum(self.tgt_glyph_obj.numsym_map.values(), []) |
|
) |
|
|
|
|
|
input_dim = self.in_glyph_obj.size() |
|
output_dim = self.tgt_glyph_obj.size() |
|
enc_emb_dim = 300 |
|
dec_emb_dim = 300 |
|
enc_hidden_dim = 512 |
|
dec_hidden_dim = 512 |
|
rnn_type = "lstm" |
|
enc2dec_hid = True |
|
attention = True |
|
enc_layers = 1 |
|
dec_layers = 2 |
|
m_dropout = 0 |
|
enc_bidirect = True |
|
enc_outstate_dim = enc_hidden_dim * (2 if enc_bidirect else 1) |
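        # With the bidirectional encoder above, attention operates over encoder
        # outputs of size enc_hidden_dim * 2.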
|
|
|
enc = Encoder( |
|
input_dim=input_dim, |
|
embed_dim=enc_emb_dim, |
|
hidden_dim=enc_hidden_dim, |
|
rnn_type=rnn_type, |
|
layers=enc_layers, |
|
dropout=m_dropout, |
|
device=self.device, |
|
bidirectional=enc_bidirect, |
|
) |
|
dec = Decoder( |
|
output_dim=output_dim, |
|
embed_dim=dec_emb_dim, |
|
hidden_dim=dec_hidden_dim, |
|
rnn_type=rnn_type, |
|
layers=dec_layers, |
|
dropout=m_dropout, |
|
use_attention=attention, |
|
enc_outstate_dim=enc_outstate_dim, |
|
device=self.device, |
|
) |
|
self.model = Seq2Seq(enc, dec, pass_enc2dec_hid=enc2dec_hid, device=self.device) |
|
self.model = self.model.to(self.device) |
|
weights = torch.load(weight_path, map_location=torch.device(self.device)) |
|
|
|
self.model.load_state_dict(weights) |
|
self.model.eval() |
|
|
|
def character_model(self, word, beam_width=1): |
|
in_vec = torch.from_numpy(self.in_glyph_obj.word2xlitvec(word)).to(self.device) |
|
|
|
p_out_list = self.model.active_beam_inference(in_vec, beam_width=beam_width) |
|
result = [ |
|
self.tgt_glyph_obj.xlitvec2word(out.cpu().numpy()) for out in p_out_list |
|
] |
|
|
|
|
|
return result |
|
|
|
def numsym_model(self, seg): |
|
"""tgt_glyph_obj.numsym_map[x] returns a list object""" |
|
if len(seg) == 1: |
|
return [seg] + self.tgt_glyph_obj.numsym_map[seg] |
|
|
|
a = [self.tgt_glyph_obj.numsym_map[n][0] for n in seg] |
|
return [seg] + ["".join(a)] |
|
|
|
    def _word_segmenter(self, sequence):
|
|
|
sequence = sequence.lower() |
|
accepted = set().union(self._numsym_set, self._inchar_set, self._natscr_set) |
|
|
|
|
|
segment = [] |
|
idx = 0 |
|
seq_ = list(sequence) |
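        # Greedily consume maximal runs of numerals/symbols, native-script glyphs,
        # Latin characters, and unrecognised characters (in that order), emitting
        # one segment per run.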
|
while len(seq_): |
|
|
|
temp = "" |
|
while len(seq_) and seq_[0] in self._numsym_set: |
|
temp += seq_[0] |
|
seq_.pop(0) |
|
if temp != "": |
|
segment.append(temp) |
|
|
|
|
|
temp = "" |
|
while len(seq_) and seq_[0] in self._natscr_set: |
|
temp += seq_[0] |
|
seq_.pop(0) |
|
if temp != "": |
|
segment.append(temp) |
|
|
|
|
|
temp = "" |
|
while len(seq_) and seq_[0] in self._inchar_set: |
|
temp += seq_[0] |
|
seq_.pop(0) |
|
if temp != "": |
|
segment.append(temp) |
|
|
|
temp = "" |
|
while len(seq_) and seq_[0] not in accepted: |
|
temp += seq_[0] |
|
seq_.pop(0) |
|
if temp != "": |
|
segment.append(temp) |
|
|
|
return segment |
|
|
|
def inferencer(self, sequence, beam_width=10): |
|
|
|
        seg = self._word_segmenter(sequence[:120])
|
lit_seg = [] |
|
|
|
p = 0 |
|
while p < len(seg): |
|
if seg[p][0] in self._natscr_set: |
|
lit_seg.append([seg[p]]) |
|
p += 1 |
|
|
|
elif seg[p][0] in self._inchar_set: |
|
lit_seg.append(self.character_model(seg[p], beam_width=beam_width)) |
|
p += 1 |
|
|
|
elif seg[p][0] in self._numsym_set: |
|
lit_seg.append(self.numsym_model(seg[p])) |
|
p += 1 |
|
else: |
|
lit_seg.append([seg[p]]) |
|
p += 1 |
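        # Combine per-segment candidates: a single segment returns its candidate list,
        # two segments take the cartesian product of their candidates, and longer
        # inputs just join each segment's top candidate.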
|
|
|
|
|
|
|
if len(lit_seg) == 1: |
|
final_result = lit_seg[0] |
|
|
|
elif len(lit_seg) == 2: |
|
final_result = [""] |
|
for seg in lit_seg: |
|
new_result = [] |
|
for s in seg: |
|
for f in final_result: |
|
new_result.append(f + s) |
|
final_result = new_result |
|
|
|
else: |
|
new_result = [] |
|
for seg in lit_seg: |
|
new_result.append(seg[0]) |
|
final_result = ["".join(new_result)] |
|
|
|
return final_result |
|
|
|
|
|
class XlitEngine: |
|
""" |
|
For Managing the top level tasks and applications of transliteration |
|
Global Variables: F_DIR |
|
""" |
|
|
|
def __init__(self, lang2use="hi", config_path="models/default_lineup.json"): |
|
lineup = json.load(open(os.path.join(F_DIR, config_path), encoding="utf-8")) |
|
models_path = os.path.join(F_DIR, "models") |
|
self.lang_config = {} |
|
if lang2use in lineup: |
|
self.lang_config[lang2use] = lineup[lang2use] |
|
else: |
|
            raise Exception(
                "XlitError: The entered language code was not found. Available: {}".format(
                    list(lineup.keys())
                )
            )
|
self.langs = {} |
|
self.lang_model = {} |
|
for la in self.lang_config: |
|
try: |
|
print("Loading {}...".format(la)) |
|
self.lang_model[la] = XlitPiston( |
|
weight_path=os.path.join( |
|
models_path, self.lang_config[la]["weight"] |
|
), |
|
tglyph_cfg_file=os.path.join( |
|
models_path, self.lang_config[la]["script"] |
|
), |
|
iglyph_cfg_file="en", |
|
) |
|
self.langs[la] = self.lang_config[la]["name"] |
|
except Exception as error: |
|
print("XlitError: Failure in loading {} \n".format(la), error) |
|
print(XlitError.loading_err.value) |
|
|
|
def translit_word(self, eng_word, lang_code="hi", topk=7, beam_width=10): |
|
if eng_word == "": |
|
return [] |
|
if lang_code in self.langs: |
|
try: |
|
res_list = self.lang_model[lang_code].inferencer( |
|
eng_word, beam_width=beam_width |
|
) |
|
return res_list[:topk] |
|
|
|
except Exception as error: |
|
print("XlitError:", traceback.format_exc()) |
|
print(XlitError.internal_err.value) |
|
return XlitError.internal_err |
|
else: |
|
print("XlitError: Unknown Langauge requested", lang_code) |
|
print(XlitError.lang_err.value) |
|
return XlitError.lang_err |
|
|
|
def translit_sentence(self, eng_sentence, lang_code="hi", beam_width=10): |
|
if eng_sentence == "": |
|
return [] |
|
|
|
if lang_code in self.langs: |
|
try: |
|
out_str = "" |
|
for word in eng_sentence.split(): |
|
res_ = self.lang_model[lang_code].inferencer( |
|
word, beam_width=beam_width |
|
) |
|
out_str = out_str + res_[0] + " " |
|
return out_str[:-1] |
|
|
|
except Exception as error: |
|
print("XlitError:", traceback.format_exc()) |
|
print(XlitError.internal_err.value) |
|
return XlitError.internal_err |
|
|
|
else: |
|
print("XlitError: Unknown Langauge requested", lang_code) |
|
print(XlitError.lang_err.value) |
|
return XlitError.lang_err |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
engine = XlitEngine() |
|
y = engine.translit_sentence("Hello World !") |
|
print(y) |
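    # Word-level transliteration returning up to top-k candidates (hedged example;
    # assumes the default "hi" model listed in models/default_lineup.json is
    # available locally):
    # print(engine.translit_word("namaste", lang_code="hi", topk=5))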
|
|