# NOTE(review): the export contained stray status lines ("Spaces:",
# "Runtime error" x2) -- apparently page text captured alongside the code.
# Preserved as a comment so the file parses as Python.
import itertools

import numpy as np
import torch
from tqdm.auto import tqdm
def get_char_probs(texts, predictions, tokenizer):
    """
    Map token-level predictions back onto character positions of each text.

    Parameters
    ----------
    texts : list of str
        Raw input texts.
    predictions : sequence
        Per-text sequences of token-level scores, aligned with the
        tokenizer's token order for that text.
    tokenizer : callable
        Hugging Face-style tokenizer supporting ``return_offsets_mapping=True``.

    Returns
    -------
    list of np.ndarray
        One float array per text, ``len(text)`` long; each character carries
        the score of the token whose offset span covers it.
    """
    results = [np.zeros(len(t)) for t in texts]
    for i, (text, prediction) in enumerate(zip(texts, predictions)):
        encoded = tokenizer(text,
                            add_special_tokens=True,
                            return_offsets_mapping=True)
        for offset_mapping, pred in zip(encoded['offset_mapping'], prediction):
            start = offset_mapping[0]
            end = offset_mapping[1]
            # Special tokens map to (0, 0), so this assignment is a no-op for
            # them; real tokens paint their score over [start, end).
            results[i][start:end] = pred
    return results
def get_results(char_probs, th=0.5):
    """
    Convert per-character probabilities into "start end" span strings.

    For each array in *char_probs*, characters with probability >= *th* are
    selected, their indices shifted by one (pipeline convention), grouped
    into consecutive runs, and each run rendered as ``"min max"``; runs for
    one text are joined with ``';'``. A text with no character at/above the
    threshold yields ``""``.

    Example: [0.1, 0.1, 0.9, 0.9, 0.9, 0.9, 0.2, 0.2, 0.2, 0.7, 0.7, 0.7]
    -> "3 6;10 12"
    """
    results = []
    for char_prob in char_probs:
        # 1-shifted positions of characters at/above the threshold.
        hits = (np.where(char_prob >= th)[0] + 1).tolist()
        spans = []
        for pos in hits:
            if spans and pos == spans[-1][1] + 1:
                spans[-1][1] = pos        # extend the current run
            else:
                spans.append([pos, pos])  # start a new run
        results.append(";".join(f"{lo} {hi}" for lo, hi in spans))
    return results
def get_predictions(results):
    """
    Parse location strings like ``"2 5;9 11"`` into ``[[2, 5], [9, 11]]``.

    Parameters
    ----------
    results : iterable of str
        Each entry is a ';'-separated list of "start end" spans, or "" for
        no prediction.

    Returns
    -------
    list of list of [int, int]
        One list of [start, end] pairs per input string; "" maps to [].
    """
    predictions = []
    for result in results:
        prediction = []
        if result != "":
            for loc in result.split(';'):
                parts = loc.split()
                prediction.append([int(parts[0]), int(parts[1])])
        predictions.append(prediction)
    return predictions
def inference_fn(test_loader, model, device):
    """
    Run *model* over *test_loader* and return sigmoid probabilities.

    Parameters
    ----------
    test_loader : iterable
        Yields dicts of input tensors (moved to *device* here).
    model : torch.nn.Module
        Forward takes the input dict and returns logits.
    device : torch.device
        Target device for model and inputs.

    Returns
    -------
    np.ndarray
        Sigmoid outputs of all batches, concatenated along axis 0.
    """
    preds = []
    model.eval()
    model.to(device)
    tk0 = tqdm(test_loader, total=len(test_loader))
    for inputs in tk0:
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        with torch.no_grad():
            y_preds = model(inputs)
        # .cpu() before .numpy(): a CUDA tensor cannot be converted to
        # NumPy directly -- the original raised TypeError on GPU devices.
        preds.append(y_preds.sigmoid().cpu().numpy())
    predictions = np.concatenate(preds)
    return predictions
def get_text(context, indexes):
    """
    Extract the substring(s) of *context* named by *indexes*.

    Parameters
    ----------
    context : str
        Source text to slice.
    indexes : str or None
        Either falsy (returns the not-found sentinel) or a ';'-separated
        list of "start end" character spans.

    Returns
    -------
    str
        The selected span(s), multiple spans joined by a single space, or
        'Not found in this Context' when *indexes* is empty/None.

    Fix vs. original: the multi-span branch prepended ' ' before every span,
    so multi-span results started with a stray leading space while the
    single-span branch did not; span extraction is now uniform.
    """
    if not indexes:
        return 'Not found in this Context'
    pieces = []
    for chunk in indexes.split(';'):
        parts = chunk.split(' ')
        start, end = int(parts[0]), int(parts[1])
        pieces.append(context[start:end])
    return ' '.join(pieces)