radiobee-aligner / radiobee /gen_row_alignment.py
freemt
Update before sent-align
4c04f50
raw
history blame
4.23 kB
"""Gen proper alignment for a given triple_set.
cmat = fetch_sent_corr(src, tgt)
src_len, tgt_len = np.array(cmat).shape
r_ali = gen_row_alignment(cmat, tgt_len, src_len) # note the order
src[r_ali[1]], tgt[r_ali[0]], r_ali[2]
or !!! (target, source)
cmat = fetch_sent_corr(tgt, src) # note the order
src_len, tgt_len = np.array(cmat).shape
r_ali = gen_row_alignment(cmat, src_len, tgt_len)
src[r_ali[0]], tgt[r_ali[1]], r_ali[2]
---
src_txt = 'data/wu_ch2_en.txt'
tgt_txt = 'data/wu_ch2_zh.txt'
assert Path(src_txt).exists()
assert Path(tgt_txt).exists()
src_text, _ = load_paras(src_txt)
tgt_text, _ = load_paras(tgt_txt)
cos_matrix = gen_cos_matrix(src_text, tgt_text)
t_set, m_matrix = find_aligned_pairs(cos_matrix, thr=0.4, matrix=True)
resu = gen_row_alignment(t_set, src_len, tgt_len)
resu = np.array(resu)
idx = -1
idx += 1; (resu[idx], src_text[int(resu[idx, 0])],
tgt_text[int(resu[idx, 1])]) if all(resu[idx]) else resu[idx]
idx += 1; i0, i1, i2 = resu[idx]; '***' if i0 == ''
else src_text[int(i0)], '***' if i1 == '' else tgt_text[int(i1)], ''
if i2 == '' else i2
"""
# pylint: disable=line-too-long, unused-variable
from typing import List, Union
# natural extrapolation with slope equal to 1
from itertools import zip_longest as zip_longest_middle
import numpy as np
from logzero import logger
# from tinybee.zip_longest_middle import zip_longest_middle
# from tinybee.zip_longest_middle import zip_longest_middle
# from tinybee.find_pairs import find_pairs
# logger = logging.getLogger(__name__)
# logger.addHandler(logging.NullHandler())
def gen_row_alignment(  # pylint: disable=too-many-locals
    t_set,
    src_len,
    tgt_len,
) -> List[List[Union[str, float]]]:
    """Gen proper rows for given triple_set.

    Triples are examined from the highest third-column value down (ties
    keep their original order). A triple is kept as an anchor only when
    its second item lies strictly between the second items of the anchors
    that surround it once anchors are ordered by first item, so the kept
    anchors increase in their second item. The indices left between two
    consecutive anchors are then paired off one-to-one, padding the
    shorter side with "".

    Arguments:
        [t_set {np.array or list}] -- rows of three items each; the first
            item is an index bounded by src_len, the second an index
            bounded by tgt_len, the third the value used for ranking.
            May be empty.
        [src_len {int}] -- numb of source texts (para/sents)
        [tgt_len {int}] -- numb of target texts (para/sents); at most
            this many triples (the best ranked ones) are examined

    Returns:
        [list of lists] -- rows ``[first, second, third]``: kept anchors
            carry their original third item, filler rows carry "" as the
            third item and "" for whichever index has no partner.
    """
    t_set = np.array(t_set, dtype="object")

    # head sentinel: sits before every real (non-negative) index pair
    anchors = [[-1, -1, ""]]

    # an empty t_set has nothing to unpack; skip straight to the gap filling
    if t_set.size:
        _, _, scores = zip(*t_set)

        # a stable descending sort visits rows in the same order as
        # repeatedly taking the first argmax and masking it out
        by_score = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)

        for count, idx_t in enumerate(by_score[: max(tgt_len, 0)]):
            elm = t_set[idx_t]
            logger.debug("%s: %s, %s", count, idx_t, elm)

            elm0, elm1, elm2 = elm

            # slot right after the last anchor whose first item is <= elm0
            idx = next(
                (pos for pos, loc in enumerate(anchors) if loc[0] > elm0),
                len(anchors),
            )

            # prev elm1 < elm1 < next elm1; past the last anchor the upper
            # bound is the number of target texts
            lower = anchors[idx - 1][1]
            upper = anchors[idx][1] if idx < len(anchors) else tgt_len
            if lower < elm1 < upper:
                anchors.insert(idx, [elm0, elm1, elm2])

    # tail sentinel so the indices after the last anchor get filled as well
    anchors.append([src_len, tgt_len, ""])

    resu = []
    for (prev0, prev1, _), cur in zip(anchors, anchors[1:]):
        # pair up the unanchored indices between two neighbouring anchors
        gap = zip_longest_middle(
            range(prev0 + 1, cur[0]), range(prev1 + 1, cur[1]), fillvalue="",
        )
        # filler rows have no merit, hence the trailing ""
        resu.extend([left, right, ""] for left, right in gap)
        resu.append(cur)

    # remove the last entry (the tail sentinel)
    return resu[:-1]