import logging

import torch
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


def load_model_and_tokenizer(model_name):
    """
    load_model_and_tokenizer - load a Seq2Seq model and tokenizer from the Hugging Face Hub

    Args:
        model_name (str): the name or path of the model to load

    Returns:
        AutoModelForSeq2SeqLM: the model
        AutoTokenizer: the tokenizer
    """
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        # low_cpu_mem_usage=True,
        # use_cache=False,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # move the model to GPU if one is available
    model = model.to("cuda") if torch.cuda.is_available() else model
    logging.info(f"Loaded model {model_name}")
    return model, tokenizer

def summarize(ids, mask, model, tokenizer, **kwargs):
    """
    summarize - generate a summary for a single batch of input ids and an attention mask

    Args:
        ids (torch.Tensor): the input ids for the batch
        mask (torch.Tensor): the attention mask for the batch
        model: the model to use for summarization
        tokenizer: the tokenizer to use for decoding
        **kwargs: additional arguments passed to model.generate()

    Returns:
        tuple: (summary, summary_length) - the decoded summary text and the
            token length of the generated summary
    """
    # add a batch dimension and move the tensors to the model's device
    ids = ids[None, :]
    mask = mask[None, :]
    input_ids = ids.to("cuda") if torch.cuda.is_available() else ids
    attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
    # global_attention_mask = torch.zeros_like(attention_mask)
    # put global attention on <s> token (only needed for LED-style models)
    # global_attention_mask[:, 0] = 1
    summary_pred_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        # global_attention_mask=global_attention_mask,
        return_dict_in_generate=True,
        **kwargs,
    )
    summary = tokenizer.batch_decode(
        summary_pred_ids.sequences,
        skip_special_tokens=True,
    )
    summary_length = len(summary_pred_ids.sequences[0])
    return summary, summary_length
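

# Example (illustrative only, not part of the original module): typical generation
# settings that could be forwarded to summarize() via **kwargs and passed through
# to model.generate(). The variable `encoded` and the parameter values below are
# assumptions and should be tuned for the model actually in use.
#
#   summary, n_tokens = summarize(
#       ids=encoded.input_ids[0],
#       mask=encoded.attention_mask[0],
#       model=model,
#       tokenizer=tokenizer,
#       num_beams=4,
#       max_length=256,
#       no_repeat_ngram_size=3,
#       early_stopping=True,
#   )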

def summarize_via_tokenbatches(
    input_text: str,
    model,
    tokenizer,
    batch_length=2048,
    batch_stride=16,
    **kwargs,
):
    """
    summarize_via_tokenbatches - summarize a long text by splitting it into
        overlapping token batches and summarizing each batch in turn

    Args:
        input_text (str): the text to summarize
        model: the model to use for summarization
        tokenizer: the tokenizer to use for summarization
        batch_length (int, optional): the length of each token batch. Defaults to 2048.
        batch_stride (int, optional): the number of tokens that overlap between
            consecutive batches. Defaults to 16.
        **kwargs: additional arguments passed to model.generate()

    Returns:
        list: one dict per batch, containing the input tokens, the generated
            summary, and the compression rate
    """
    # enforce a minimum batch length and log all input parameters
    if batch_length < 512:
        batch_length = 512
        print("WARNING: batch_length was set to 512")
    print(
        f"input parameters: {kwargs}, batch_length={batch_length}, batch_stride={batch_stride}"
    )
    # tokenize the input into overlapping batches of at most batch_length tokens
    encoded_input = tokenizer(
        input_text,
        padding="max_length",
        truncation=True,
        max_length=batch_length,
        stride=batch_stride,
        return_overflowing_tokens=True,
        add_special_tokens=False,
        return_tensors="pt",
    )
    in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask
    gen_summaries = []

    pbar = tqdm(total=len(in_id_arr))
    for _id, _mask in zip(in_id_arr, att_arr):
        result, summary_length = summarize(
            ids=_id,
            mask=_mask,
            model=model,
            tokenizer=tokenizer,
            **kwargs,
        )
        # compression rate: fraction of input tokens not present in the summary length
        rate = round(float((len(_id) - summary_length) / len(_id)), 3)
        _sum = {
            "input_tokens": _id,
            "summary": result,
            "compression_rate": rate,
        }
        gen_summaries.append(_sum)
        print(f"\t{result[0]}\nCompression:\t{rate}")
        pbar.update()

    pbar.close()

    return gen_summaries
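

if __name__ == "__main__":
    # Minimal usage sketch (an assumption for illustration, not part of the original
    # module): the checkpoint name and input text below are examples only and should
    # be replaced with the summarization model and document you actually want to use.
    example_model_name = "pszemraj/led-base-book-summary"  # example checkpoint
    example_text = "Long document text goes here..."

    model, tokenizer = load_model_and_tokenizer(example_model_name)
    summaries = summarize_via_tokenbatches(
        example_text,
        model,
        tokenizer,
        batch_length=2048,
        batch_stride=16,
    )
    for batch in summaries:
        print(batch["summary"][0])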