Spaces:
Running
Running
File size: 4,400 Bytes
87e5c9c 9350787 87e5c9c 9350787 87e5c9c 9350787 87e5c9c 9350787 87e5c9c 9350787 87e5c9c 9350787 87e5c9c 9350787 87e5c9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import logging
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
def load_model_and_tokenizer(model_name):
    """
    load_model_and_tokenizer - fetch a seq2seq model and its tokenizer by name

    The model is moved to "cuda" when torch reports CUDA as available and to
    "cpu" otherwise; the tokenizer is loaded from the same model name.

    Args:
        model_name (str): the name of the model to load

    Returns:
        tuple: (model, tokenizer) as returned by
            AutoModelForSeq2SeqLM.from_pretrained and AutoTokenizer.from_pretrained
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    seq2seq_model = seq2seq_model.to(device)

    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)

    logging.info(f"Loaded model {model_name} to {device}")
    return seq2seq_model, hf_tokenizer
def summarize_and_score(ids, mask, model, tokenizer, is_general_attention_model=True, **kwargs):
    """
    summarize_and_score - generate a summary for one chunk of token ids and report its score

    Args:
        ids (torch.Tensor): token ids for a single chunk, without a batch
            dimension (a leading batch dimension of size 1 is added here)
        mask (torch.Tensor): the attention mask matching ``ids``
        model (): the model to use for summarization; must provide ``generate``
        tokenizer (): the tokenizer used to decode the generated ids
        is_general_attention_model (bool, optional): when True (default) only the
            regular attention mask is passed to ``generate``. When False, a
            ``global_attention_mask`` that is 1 at position 0 and 0 elsewhere is
            passed as well.
        **kwargs: forwarded unchanged to ``model.generate``

    Returns:
        tuple: ``(summary, score)`` where ``summary`` is the result of
            ``tokenizer.batch_decode`` on the generated sequences and ``score`` is
            the first entry of ``sequences_scores`` rounded to 4 decimals (a plain
            float). If the generation output carries no ``sequences_scores``
            the score is ``float("nan")``.
    """
    # Follow the model's own device; guessing "cuda" whenever it is available
    # breaks when the model was left on the CPU or sits on another device.
    device = getattr(model, "device", None)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    input_ids = ids[None, :].to(device)
    attention_mask = mask[None, :].to(device)

    extra_inputs = {}
    if not is_general_attention_model:
        global_attention_mask = torch.zeros_like(attention_mask)
        # put global attention on <s> token
        global_attention_mask[:, 0] = 1
        extra_inputs["global_attention_mask"] = global_attention_mask

    # Separate ** expansions keep the "duplicate keyword" TypeError if a caller
    # passes one of the fixed arguments through kwargs.
    summary_pred_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        **extra_inputs,
        output_scores=True,
        return_dict_in_generate=True,
        **kwargs,
    )

    # NOTE(review): remove_invalid_values looks like a generate() option rather
    # than a decode option; it is kept here to avoid changing behavior - confirm.
    summary = tokenizer.batch_decode(
        summary_pred_ids.sequences,
        skip_special_tokens=True,
        remove_invalid_values=True,
    )

    # sequences_scores may be missing from the output (presumably when decoding
    # without beam search - confirm); return NaN rather than discarding the
    # summary that was just generated.
    sequence_scores = getattr(summary_pred_ids, "sequences_scores", None)
    if sequence_scores is None:
        score = float("nan")
    else:
        score = round(float(sequence_scores[0]), 4)
    return summary, score
def summarize_via_tokenbatches(
    input_text: str,
    model,
    tokenizer,
    batch_length=2048,
    batch_stride=16,
    **kwargs,
):
    """
    summarize_via_tokenbatches - tokenize a text into batches and summarize each batch

    Args:
        input_text (str): the text to summarize
        model (): the model to use for summarization
        tokenizer (): the tokenizer to use for summarization
        batch_length (int, optional): the length of each batch. Defaults to 2048.
            Values below 512 are raised to 512.
        batch_stride (int, optional): the stride of each batch. Defaults to 16. The stride is the number of tokens that overlap between batches.
        **kwargs: forwarded unchanged to summarize_and_score for every batch

    Returns:
        list: one dict per batch, in order, with the keys "input_tokens"
            (the batch's token ids), "summary" (the first value returned by
            summarize_and_score) and "summary_score" (the score rounded to
            4 decimals)
    """
    # enforce the minimum batch length, then report the effective parameters
    if batch_length < 512:
        batch_length = 512
        print("WARNING: batch_length was set to 512")
    print(
        f"input parameters: {kwargs}, batch_length={batch_length}, batch_stride={batch_stride}"
    )

    # return_overflowing_tokens makes the tokenizer emit one row per batch
    # instead of dropping everything past max_length
    encoded_input = tokenizer(
        input_text,
        padding="max_length",
        truncation=True,
        max_length=batch_length,
        stride=batch_stride,
        return_overflowing_tokens=True,
        add_special_tokens=False,
        return_tensors="pt",
    )
    in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask

    gen_summaries = []
    # the context manager closes the progress bar even if a batch raises
    with tqdm(total=len(in_id_arr)) as pbar:
        for _id, _mask in zip(in_id_arr, att_arr):
            result, score = summarize_and_score(
                ids=_id,
                mask=_mask,
                model=model,
                tokenizer=tokenizer,
                **kwargs,
            )
            score = round(float(score), 4)
            gen_summaries.append(
                {
                    "input_tokens": _id,
                    "summary": result,
                    "summary_score": score,
                }
            )
            print(f"\t{result[0]}\nScore:\t{score}")
            pbar.update()
    return gen_summaries
|