Spaces:
Build error
Build error
"""Gen proper alignment for a given triple_set. | |
cmat = fetch_sent_corr(src, tgt) | |
src_len, tgt_len = np.array(cmat).shape | |
r_ali = gen_row_alignment(cmat, tgt_len, src_len) # note the order | |
src[r_ali[1]], tgt[r_ali[0]], r_ali[2] | |
or !!! (target, source) | |
cmat = fetch_sent_corr(tgt, src) # note the order | |
src_len, tgt_len = np.array(cmat).shape | |
r_ali = gen_row_alignment(cmat, src_len, tgt_len) | |
src[r_ali[0]], tgt[r_ali[1]], r_ali[2] | |
--- | |
src_txt = 'data/wu_ch2_en.txt' | |
tgt_txt = 'data/wu_ch2_zh.txt' | |
assert Path(src_txt).exists() | |
assert Path(tgt_txt).exists() | |
src_text, _ = load_paras(src_txt) | |
tgt_text, _ = load_paras(tgt_txt) | |
cos_matrix = gen_cos_matrix(src_text, tgt_text) | |
t_set, m_matrix = find_aligned_pairs(cos_matrix, thr=0.4, matrix=True) | |
resu = gen_row_alignment(t_set, src_len, tgt_len) | |
resu = np.array(resu) | |
idx = -1 | |
idx += 1; (resu[idx], src_text[int(resu[idx, 0])], | |
tgt_text[int(resu[idx, 1])]) if all(resu[idx]) else resu[idx] | |
idx += 1; i0, i1, i2 = resu[idx]; '***' if i0 == '' | |
else src_text[int(i0)], '***' if i1 == '' else tgt_text[int(i1)], '' | |
if i2 == '' else i2 | |
""" | |
# pylint: disable=line-too-long | |
from typing import List, Union | |
# natural extrapolation with slope equal to 1 | |
from itertools import zip_longest as zip_longest_middle | |
import numpy as np | |
from logzero import logger | |
# from tinybee.zip_longest_middle import zip_longest_middle | |
# from tinybee.zip_longest_middle import zip_longest_middle | |
# from tinybee.find_pairs import find_pairs | |
# logger = logging.getLogger(__name__) | |
# logger.addHandler(logging.NullHandler()) | |
def gen_row_alignment(  # pylint: disable=too-many-locals
    t_set,
    src_len,
    tgt_len,
) -> List[List[Union[str, float]]]:
    """Gen proper rows for a given triple_set.

    The triples are visited in decreasing order of their third item
    (the score).  A triple is kept as an anchor only if, once placed by
    its source index, the target indices of the kept anchors remain
    strictly increasing.  The index gaps between consecutive anchors
    (including before the first and after the last one) are then filled
    with the skipped source and target indices, paired up and padded
    with "" by zip_longest_middle; filler rows carry "" as their score.

    Arguments:
        [t_set {np.array or list}] -- [triples (src_idx, tgt_idx, score),
            may be empty]
        [src_len {int}] -- numb of source texts (para/sents)
        [tgt_len {int}] -- numb of target texts (para/sents)

    Returns:
        [list] -- [rows of [src_idx or "", tgt_idx or "", score or ""]]
    """
    t_set = np.array(t_set, dtype="object")

    # anchors kept so far, ordered by source index; the head sentinel
    # [-1, -1, ""] sits right before the first valid (src, tgt) pair
    buff = [[-1, -1, ""]]

    # an empty t_set has no scores to rank: skip the selection so that
    # every index ends up as gap filler instead of raising on unpacking
    # (len() is used because a numpy array has no usable truth value)
    if len(t_set) > 0:
        scores = np.array([elm[2] for elm in t_set])
        # a visited score is pushed below every real score so that the
        # not-yet-visited triples are always picked first
        reset_v = np.min(scores) - 1

        # note: the number of visits is tgt_len, not len(t_set)
        for count in range(tgt_len):
            idx_t = np.argmax(scores)
            scores[idx_t] = reset_v

            elm = t_set[idx_t]
            logger.debug("%s: %s, %s", count, idx_t, elm)
            elm0, elm1, elm2 = elm

            # insertion point: in front of the first anchor with a larger
            # source index, or at the very end if there is none
            idx = next(
                (pos for pos, loc in enumerate(buff) if loc[0] > elm0),
                len(buff),
            )

            # upper bound for elm1: the next anchor's target index, or
            # tgt_len when the triple would become the last anchor
            next_elm = buff[idx][1] if idx < len(buff) else tgt_len

            # keep only if prev elm1 < elm1 < next elm1
            if buff[idx - 1][1] < elm1 < next_elm:
                buff.insert(idx, [elm0, elm1, elm2])

    # take care of the tail: the sentinel closes the last gap
    buff.append([src_len, tgt_len, ""])

    resu = []
    for prev, curr in zip(buff, buff[1:]):
        # indices strictly between two consecutive anchors
        gap = zip_longest_middle(
            list(range(prev[0] + 1, curr[0])),
            list(range(prev[1] + 1, curr[1])),
            fillvalue="",
        )
        # convert to list entries & attach an empty merit
        resu += [list(pair) + [""] for pair in gap]
        resu.append(curr)

    # remove the last entry (the tail sentinel)
    return resu[:-1]