Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Replabel transforms for use with flashlight's ASG criterion.
"""
def replabel_symbol(i):
    """
    Return the replabel symbol for the value ``i``.

    Flashlight currently spells replabel symbols as the plain strings
    "1", "2", ...  As a consequence numeral tokens cannot be used for
    training, so this spelling might change in the future.
    """
    symbol = str(i)
    return symbol
def pack_replabels(tokens, dictionary, max_reps):
    """
    Pack a token sequence so that repeated symbols are replaced by replabels.

    A token is emitted once; each immediately following identical token (up
    to ``max_reps`` of them) is counted instead of emitted, and the count is
    then written out as the dictionary index of the matching replabel
    symbol. If a run is longer than that, the token is emitted again and
    counting restarts.

    ``dictionary`` only needs an ``index(symbol)`` method. When ``tokens`` is
    empty or ``max_reps <= 0`` the input object itself is returned unchanged;
    otherwise a new list is returned.
    """
    if len(tokens) == 0 or max_reps <= 0:
        return tokens

    # Slot 0 is a placeholder so that a repetition count indexes the table
    # directly; counts run from 1 to max_reps.
    rep_index = [0] + [
        dictionary.index(replabel_symbol(count))
        for count in range(1, max_reps + 1)
    ]

    packed = []
    last = -1  # sentinel: no previous token seen yet
    run = 0  # repetitions of `last` counted but not yet written out
    for tok in tokens:
        if tok == last and run < max_reps:
            run += 1
            continue
        # Either a different token or the run is full: flush the pending
        # count, then emit the token itself.
        if run > 0:
            packed.append(rep_index[run])
            run = 0
        packed.append(tok)
        last = tok

    # Flush a run that reaches the end of the sequence.
    if run > 0:
        packed.append(rep_index[run])
    return packed
def unpack_replabels(tokens, dictionary, max_reps):
    """
    Unpack a token sequence so that replabels are replaced by repeated symbols.

    Args:
        tokens: sequence of token indices, possibly containing replabel
            indices.
        dictionary: object whose ``index(symbol)`` method returns the index
            of a symbol; used to look up the replabel symbols for the values
            ``1 .. max_reps``.
        max_reps: largest repetition value that has a replabel symbol.

    Returns:
        ``tokens`` itself when it is empty or ``max_reps <= 0``; otherwise a
        new list in which every replabel is replaced by that many copies of
        the token immediately preceding it.

    A replabel that has nothing to repeat (at the very start of the sequence,
    or directly after another replabel) is dropped. Previously such a
    replabel was expanded into ``-1`` placeholder entries, which are not
    valid tokens.
    """
    if len(tokens) == 0 or max_reps <= 0:
        return tokens

    # Maps the dictionary index of each replabel symbol to its repeat count.
    replabel_idx_to_value = {
        dictionary.index(replabel_symbol(i)): i
        for i in range(1, max_reps + 1)
    }

    result = []
    prev_token = None  # None: there is currently no token to repeat
    for token in tokens:
        num_reps = replabel_idx_to_value.get(token)
        if num_reps is None:
            # Ordinary token: copy it through and remember it for a
            # following replabel.
            result.append(token)
            prev_token = token
        elif prev_token is not None:
            result.extend([prev_token] * num_reps)
            # A replabel consumes the token it repeats, so two replabels in
            # a row do not repeat the same token twice.
            prev_token = None
        # else: dangling replabel with nothing to repeat -- skip it.
    return result