import random | |
import torch | |
def getLengthParam(text: str, tokenizer) -> str: | |
tokens_count = len(tokenizer.encode(text)) | |
if tokens_count <= 15: | |
len_param = '1' | |
elif tokens_count <= 50: | |
len_param = '2' | |
elif tokens_count <= 256: | |
len_param = '3' | |
else: | |
len_param = '-' | |
return len_param | |
# Эта функция вычисляет длину ожидаемого ответа на основе инпута | |
def calcAnswerLengthByProbability(lengthId): | |
# Вспомогательная функция, для работы с вероятностями | |
# На вход подаем список веротностей для длинного ответа (3), среднего(2), короткого 1 | |
def getLenght(probList): | |
rndNum = random.randrange(start=0, stop=100, step=1) | |
if 0 <= rndNum <= probList[0]: | |
return 3 | |
elif probList[0] < rndNum <= probList[1]: | |
return 2 | |
else: | |
return 1 | |
return { | |
lengthId == '3' or lengthId == '-': getLenght([60, 90]), # до 60 - 3, от 60 до 90 2, остальное - 1 | |
lengthId == '2': getLenght([25, 75]), # до 25 - 3, от 25 до 75 - 2, остальное - 2 | |
lengthId == '1': getLenght([20, 50]), # до 20 - 3, от 20 до 50 - 2, остальное - 1 | |
}[True] | |
# Функция для обрезки контекста | |
# tensor - входной тензор | |
# size - сколько ПОСЛЕДНИХ ответов нужно оставить | |
def cropContext(tensor, size): | |
# переводим в размерность, удобную для работы | |
tensor = tensor[-1] | |
# Список, содержащий начала предложений | |
beginList = [] | |
for i, item in enumerate(tensor): | |
if (i < len(tensor) - 5 and item == 96 and tensor[i + 2] == 96 and tensor[i + 4] == 96): | |
beginList.append(i) | |
if (len(beginList) < size): | |
return torch.unsqueeze(tensor, 0) | |
neededIndex = beginList[-size] | |
# Возвращаем в нужном нам формате (добавляем одну размерность) | |
return torch.unsqueeze(tensor[neededIndex:], 0) |