funbot / util_funcs.py
Andrey Vorozhko
Import torch in utils functions
65df198
import random
import torch
def getLengthParam(text: str, tokenizer) -> str:
tokens_count = len(tokenizer.encode(text))
if tokens_count <= 15:
len_param = '1'
elif tokens_count <= 50:
len_param = '2'
elif tokens_count <= 256:
len_param = '3'
else:
len_param = '-'
return len_param
# Эта функция вычисляет длину ожидаемого ответа на основе инпута
def calcAnswerLengthByProbability(lengthId):
# Вспомогательная функция, для работы с вероятностями
# На вход подаем список веротностей для длинного ответа (3), среднего(2), короткого 1
def getLenght(probList):
rndNum = random.randrange(start=0, stop=100, step=1)
if 0 <= rndNum <= probList[0]:
return 3
elif probList[0] < rndNum <= probList[1]:
return 2
else:
return 1
return {
lengthId == '3' or lengthId == '-': getLenght([60, 90]), # до 60 - 3, от 60 до 90 2, остальное - 1
lengthId == '2': getLenght([25, 75]), # до 25 - 3, от 25 до 75 - 2, остальное - 2
lengthId == '1': getLenght([20, 50]), # до 20 - 3, от 20 до 50 - 2, остальное - 1
}[True]
# Функция для обрезки контекста
# tensor - входной тензор
# size - сколько ПОСЛЕДНИХ ответов нужно оставить
def cropContext(tensor, size):
# переводим в размерность, удобную для работы
tensor = tensor[-1]
# Список, содержащий начала предложений
beginList = []
for i, item in enumerate(tensor):
if (i < len(tensor) - 5 and item == 96 and tensor[i + 2] == 96 and tensor[i + 4] == 96):
beginList.append(i)
if (len(beginList) < size):
return torch.unsqueeze(tensor, 0)
neededIndex = beginList[-size]
# Возвращаем в нужном нам формате (добавляем одну размерность)
return torch.unsqueeze(tensor[neededIndex:], 0)