entity-extraction / utils.py
RishuD7's picture
first commit
122d428
raw
history blame
3.31 kB
import itertools
import torch
import numpy as np
from tqdm.auto import tqdm
def get_char_probs(texts, predictions, tokenizer):
    """
    Spread per-token predictions over the characters of each text.

    Each text is re-tokenized with ``add_special_tokens=True`` and
    ``return_offsets_mapping=True``. The i-th entry of the returned
    ``offset_mapping`` is paired with the i-th prediction for that text, and
    the characters in ``[start, end)`` receive that prediction value.

    Args:
        texts: sequence of strings.
        predictions: one sequence of per-token scores per text
            (presumably aligned with the tokenizer output -- confirm
            against the caller).
        tokenizer: callable returning a mapping with an
            ``'offset_mapping'`` entry of ``(start, end)`` pairs.

    Returns:
        A list with one float array per text, each of length ``len(text)``.
        Characters not covered by any offset keep the initial value 0; where
        offsets overlap, the later token overwrites the earlier one.
    """
    # Allocate up front so every text gets an array even if `predictions`
    # is shorter than `texts` (zip stops at the shorter input).
    char_scores = [np.zeros(len(t)) for t in texts]

    for scores, text, token_preds in zip(char_scores, texts, predictions):
        encoding = tokenizer(
            text,
            add_special_tokens=True,
            return_offsets_mapping=True,
        )
        for span, token_score in zip(encoding['offset_mapping'], token_preds):
            begin, finish = span[0], span[1]
            # An empty span such as (0, 0) selects nothing and is a no-op.
            scores[begin:finish] = token_score

    return char_scores
def get_results(char_probs, th=0.5):
    """
    Turn per-character probabilities into ``"start end"`` location strings.

    For each probability sequence, the positions whose value is ``>= th``
    are collected, shifted by +1, and grouped into runs of consecutive
    positions. Each run is written as ``"<first> <last>"`` and the runs of
    one sequence are joined with ``';'``.

    Because of the +1 shift, ``first`` is one greater than the 0-based index
    of the first character in the run, and ``last`` equals the 0-based
    exclusive end of the run.
    NOTE(review): the shift presumably compensates for how the tokenizer
    reports offsets -- confirm before changing it.

    Example:
        char_prob = [0.1, 0.1, 0.9, 0.9, 0.9, 0.9, 0.2, 0.2, 0.2, 0.7, 0.7, 0.7]
        0-based positions >= 0.5 -> [2, 3, 4, 5, 9, 10, 11]
        after the +1 shift       -> [3, 4, 5, 6, 10, 11, 12]
        grouped into runs        -> [[3, 4, 5, 6], [10, 11, 12]]
        result                   -> "3 6;10 12"

    Args:
        char_probs: iterable of probability sequences (NumPy arrays or plain
            Python sequences).
        th: inclusive threshold; defaults to 0.5.

    Returns:
        A list with one string per input sequence; the string is empty when
        no position reaches the threshold.
    """
    results = []
    for char_prob in char_probs:
        # np.asarray lets plain lists be compared element-wise, as the
        # example above implies; arrays pass through unchanged.
        positions = np.where(np.asarray(char_prob) >= th)[0] + 1

        # Within a run of consecutive positions, (position - rank) is
        # constant, so grouping on that difference separates the runs.
        runs = [
            [pos for _, pos in members]
            for _, members in itertools.groupby(
                enumerate(positions), key=lambda pair: pair[1] - pair[0]
            )
        ]

        # Positions are ascending, so first/last are the run's min/max.
        results.append(";".join(f"{int(run[0])} {int(run[-1])}" for run in runs))
    return results
def get_predictions(results):
    """
    Parse location strings into lists of ``[start, end]`` integer pairs.

    Each entry of ``results`` is a string of ``';'``-separated spans, where
    every span is two whitespace-separated integers, e.g. ``"2 5;9 11"``.
    An empty string yields an empty list.

    Example:
        ["2 5;9 11", ""] -> [[[2, 5], [9, 11]], []]

    Args:
        results: iterable of location strings.

    Returns:
        A list with one list of ``[start, end]`` pairs per input string.
    """
    predictions = []
    for location in results:
        if location == "":
            # Nothing was found for this entry.
            predictions.append([])
            continue

        tokenized = [chunk.split() for chunk in location.split(';')]
        predictions.append([[int(parts[0]), int(parts[1])] for parts in tokenized])
    return predictions
def inference_fn(test_loader, model, device):
    """
    Run ``model`` over every batch of ``test_loader`` and collect the outputs.

    The model is switched to eval mode and moved to ``device``. Each batch is
    expected to be a dict whose values support ``.to(device)``; the whole
    dict is passed to the model as a single positional argument. The model
    output is passed through a sigmoid and converted to a NumPy array.

    Args:
        test_loader: sized iterable of batches (``len()`` is used for the
            progress bar total).
        model: module invoked as ``model(inputs)``.
        device: target device for the model and the batch values.

    Returns:
        The per-batch arrays concatenated along the first axis.

    Raises:
        ValueError: if ``test_loader`` yields no batches, because
            ``np.concatenate`` is given an empty list.
    """
    preds = []
    model.eval()
    model.to(device)
    tk0 = tqdm(test_loader, total=len(test_loader))
    for inputs in tk0:
        # Values are moved in place; the set of keys is unchanged.
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        with torch.no_grad():
            y_preds = model(inputs)
        # Tensor.numpy() only works on CPU tensors, so copy back to host
        # first; .cpu() is a no-op when the tensor is already on the CPU.
        preds.append(y_preds.sigmoid().cpu().numpy())
    predictions = np.concatenate(preds)
    return predictions
def get_text(context, indexes):
    """
    Extract the substrings of ``context`` named by a location string.

    ``indexes`` holds one or more ``"start end"`` spans separated by ``';'``
    (for example ``"2 5;9 11"``). Each span selects
    ``context[start:end]``; the selected pieces are joined with a single
    space, with no leading or trailing separator.

    Args:
        context: the text to slice.
        indexes: location string, or a falsy value (``""`` / ``None``) when
            nothing was found.

    Returns:
        The joined substrings, or the literal message
        ``'Not found in this Context'`` when ``indexes`` is falsy.

    Raises:
        ValueError: if a span does not contain two integers separated by a
            single space.
    """
    if not indexes:
        return 'Not found in this Context'

    pieces = []
    # Splitting on ';' also covers the single-span case (one element).
    for span in indexes.split(';'):
        bounds = span.split(' ')
        start_index = int(bounds[0])
        end_index = int(bounds[1])
        pieces.append(context[start_index:end_index])

    # Joining avoids the leading space the multi-span path used to emit.
    return ' '.join(pieces)