pszemraj commited on
Commit
9350787
1 Parent(s): 87e5c9c

🚧 update for longt5

Browse files

Signed-off-by: peter szemraj <peterszemraj@gmail.com>

Files changed (1) hide show
  1. summarize.py +23 -13
summarize.py CHANGED
@@ -15,20 +15,19 @@ def load_model_and_tokenizer(model_name):
15
  AutoModelForSeq2SeqLM: the model
16
  AutoTokenizer: the tokenizer
17
  """
18
-
19
  model = AutoModelForSeq2SeqLM.from_pretrained(
20
  model_name,
21
  # low_cpu_mem_usage=True,
22
  # use_cache=False,
23
- )
24
  tokenizer = AutoTokenizer.from_pretrained(model_name)
25
- model = model.to("cuda") if torch.cuda.is_available() else model
26
 
27
- logging.info(f"Loaded model {model_name}")
28
  return model, tokenizer
29
 
30
 
31
- def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
32
  """
33
  summarize_and_score - given a batch of ids and a mask, return a summary and a score for the summary
34
 
@@ -37,6 +36,7 @@ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
37
  mask (): the attention mask for the batch
38
  model (): the model to use for summarization
39
  tokenizer (): the tokenizer to use for summarization
 
40
 
41
  Returns:
42
  str: the summary of the batch
@@ -52,14 +52,23 @@ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
52
  # put global attention on <s> token
53
  global_attention_mask[:, 0] = 1
54
 
55
- summary_pred_ids = model.generate(
56
- input_ids,
57
- attention_mask=attention_mask,
58
- global_attention_mask=global_attention_mask,
59
- output_scores=True,
60
- return_dict_in_generate=True,
61
- **kwargs,
62
- )
 
 
 
 
 
 
 
 
 
63
  summary = tokenizer.batch_decode(
64
  summary_pred_ids.sequences,
65
  skip_special_tokens=True,
@@ -70,6 +79,7 @@ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
70
  return summary, score
71
 
72
 
 
73
  def summarize_via_tokenbatches(
74
  input_text: str,
75
  model,
 
15
  AutoModelForSeq2SeqLM: the model
16
  AutoTokenizer: the tokenizer
17
  """
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
  model = AutoModelForSeq2SeqLM.from_pretrained(
20
  model_name,
21
  # low_cpu_mem_usage=True,
22
  # use_cache=False,
23
+ ).to(device)
24
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
25
 
26
+ logging.info(f"Loaded model {model_name} to {device}")
27
  return model, tokenizer
28
 
29
 
30
+ def summarize_and_score(ids, mask, model, tokenizer, is_general_attention_model=True, **kwargs):
31
  """
32
  summarize_and_score - given a batch of ids and a mask, return a summary and a score for the summary
33
 
 
36
  mask (): the attention mask for the batch
37
  model (): the model to use for summarization
38
  tokenizer (): the tokenizer to use for summarization
39
+ is_general_attention_model (bool, optional): whether the model is a general attention model. Defaults to True.
40
 
41
  Returns:
42
  str: the summary of the batch
 
52
  # put global attention on <s> token
53
  global_attention_mask[:, 0] = 1
54
 
55
+ if is_general_attention_model:
56
+ summary_pred_ids = model.generate(
57
+ input_ids,
58
+ attention_mask=attention_mask,
59
+ output_scores=True,
60
+ return_dict_in_generate=True,
61
+ **kwargs,
62
+ )
63
+ else:
64
+ summary_pred_ids = model.generate(
65
+ input_ids,
66
+ attention_mask=attention_mask,
67
+ global_attention_mask=global_attention_mask,
68
+ output_scores=True,
69
+ return_dict_in_generate=True,
70
+ **kwargs,
71
+ )
72
  summary = tokenizer.batch_decode(
73
  summary_pred_ids.sequences,
74
  skip_special_tokens=True,
 
79
  return summary, score
80
 
81
 
82
+
83
  def summarize_via_tokenbatches(
84
  input_text: str,
85
  model,