File size: 4,350 Bytes
046f8fb 36460c9 046f8fb 36460c9 046f8fb 36460c9 046f8fb 36460c9 046f8fb ed4687b 36460c9 046f8fb ed4687b 046f8fb ed4687b 046f8fb 36460c9 046f8fb 36460c9 046f8fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import logging
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
def load_model_and_tokenizer(model_name):
    """
    load_model_and_tokenizer - fetch a pretrained seq2seq model and its tokenizer

    Args:
        model_name (str): identifier handed to ``from_pretrained`` for both
            the model and the tokenizer

    Returns:
        tuple: ``(model, tokenizer)``; the model is moved to CUDA when torch
            reports a CUDA device as available, otherwise returned as loaded
    """
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    text_tokenizer = AutoTokenizer.from_pretrained(model_name)

    if torch.cuda.is_available():
        seq2seq_model = seq2seq_model.to("cuda")

    logging.info(f"Loaded model {model_name}")
    return seq2seq_model, text_tokenizer
def summarize(ids, mask, model, tokenizer, model_arch, **kwargs):
    """
    summarize - generate a summary for a single token batch

    Args:
        ids (torch.Tensor): token ids of one batch; a leading batch dimension
            is added here before generation
        mask (torch.Tensor): attention mask matching ``ids``
        model (): the model to use for summarization (must provide ``generate``)
        tokenizer (): the tokenizer used to decode the generated ids
        model_arch (str): ``'LED'`` enables a global attention mask on the
            first token; any other value uses plain generation
        **kwargs: forwarded unchanged to ``model.generate``

    Returns:
        list[str]: the decoded summaries, one entry per generated sequence
    """
    # Follow the model's own device when it exposes one, so inputs and weights
    # never end up on different devices; otherwise keep the old rule
    # ("cuda" whenever it is available).
    device = getattr(model, "device", None)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    input_ids = ids[None, :].to(device)
    attention_mask = mask[None, :].to(device)

    generate_kwargs = dict(kwargs)
    if model_arch == "LED":
        global_attention_mask = torch.zeros_like(attention_mask)
        # put global attention on <s> token
        global_attention_mask[:, 0] = 1
        # a mask explicitly supplied by the caller takes precedence
        generate_kwargs.setdefault("global_attention_mask", global_attention_mask)

    summary_pred_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        return_dict_in_generate=True,
        **generate_kwargs,
    )

    # NOTE(review): remove_invalid_values looks like a generate() option rather
    # than a decode option; it is kept here unchanged - confirm before moving it.
    summary = tokenizer.batch_decode(
        summary_pred_ids.sequences,
        skip_special_tokens=True,
        remove_invalid_values=True,
    )
    return summary
# Checkpoint names treated as LED models (spelled exactly as in the original check).
_LED_CHECKPOINTS = frozenset(
    {
        "Blaise-g/led_pubmed_sumpubmed_1",
        "Blaise-g/led_large_sumpbumed_scitldr",
    }
)


def _detect_model_arch(model):
    """Return 'LED' when the model looks like an LED checkpoint, else 'LongT5'."""
    # Keep the original behaviour for callers that pass a checkpoint name.
    if isinstance(model, str):
        return "LED" if model in _LED_CHECKPOINTS else "LongT5"

    # Loaded Hugging Face models carry a config with `model_type` and record
    # the checkpoint they were loaded from as `name_or_path`.
    config = getattr(model, "config", None)
    model_type = getattr(config, "model_type", None)
    name_or_path = getattr(model, "name_or_path", None) or getattr(
        config, "name_or_path", None
    )
    if model_type == "led" or name_or_path in _LED_CHECKPOINTS:
        return "LED"
    return "LongT5"


def summarize_via_tokenbatches(
    input_text: str,
    model,
    tokenizer,
    batch_length=2048,
    batch_stride=16,
    **kwargs,
):
    """
    summarize_via_tokenbatches - split a text into token batches and summarize each batch

    Args:
        input_text (str): the text to summarize
        model (): the model to use for summarization
        tokenizer (): the tokenizer to use for summarization
        batch_length (int, optional): the length of each batch. Defaults to 2048.
            Values below 512 are raised to 512.
        batch_stride (int, optional): the stride of each batch. Defaults to 16. The stride is the number of tokens that overlap between batches.
        **kwargs: forwarded to ``summarize`` (and from there to ``model.generate``)

    Returns:
        list[dict]: one entry per token batch with the keys
            ``"input_tokens"`` (the batch's token ids),
            ``"summary"`` (the decoded summaries for the batch, a list) and
            ``"compression_rate"`` (``(len(input_text) - len(summary)) / len(input_text)``
            rounded to 3 digits, measured against the full input text;
            ``0.0`` when ``input_text`` is empty)
    """
    # log all input parameters
    if batch_length < 512:
        batch_length = 512
        print("WARNING: batch_length was set to 512")
    print(
        f"input parameters: {kwargs}, batch_length={batch_length}, batch_stride={batch_stride}"
    )

    encoded_input = tokenizer(
        input_text,
        padding="max_length",
        truncation=True,
        max_length=batch_length,
        stride=batch_stride,
        return_overflowing_tokens=True,
        add_special_tokens=False,
        return_tensors="pt",
    )
    in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask

    # The architecture does not change between batches, so detect it once.
    # (Comparing the model object itself with a checkpoint name never matched,
    # which silently disabled the LED global-attention path.)
    model_arch = _detect_model_arch(model)
    input_length = len(input_text)

    gen_summaries = []
    # The context manager closes the progress bar even if a batch fails.
    with tqdm(total=len(in_id_arr)) as pbar:
        for _id, _mask in zip(in_id_arr, att_arr):
            result = summarize(
                ids=_id,
                mask=_mask,
                model=model,
                model_arch=model_arch,
                tokenizer=tokenizer,
                **kwargs,
            )
            # `result` is a list of decoded strings; measure the summary text,
            # not the length of the list.
            summary_length = len(result[0])
            if input_length:
                rate = round((input_length - summary_length) / input_length, 3)
            else:
                rate = 0.0
            gen_summaries.append(
                {
                    "input_tokens": _id,
                    "summary": result,
                    "compression_rate": rate,
                }
            )
            print(f"\t{result[0]}\nRate:\t{rate}")
            pbar.update()

    return gen_summaries