```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained('allenai/longformer-scico')
model = AutoModelForSequenceClassification.from_pretrained('allenai/longformer-scico')

start_token = tokenizer.convert_tokens_to_ids("<m>")
end_token = tokenizer.convert_tokens_to_ids("</m>")

# Build a Longformer global-attention mask: 1 for the CLS token and
# every <m>/</m> mention marker, 0 everywhere else.
def get_global_attention(input_ids):
    global_attention_mask = torch.zeros(input_ids.shape)
    global_attention_mask[:, 0] = 1  # global attention to the CLS token
    start = torch.nonzero(input_ids == start_token)  # global attention to the <m> tokens
    end = torch.nonzero(input_ids == end_token)  # global attention to the </m> tokens
    globs = torch.cat((start, end))
    value = torch.ones(globs.shape[0])
    global_attention_mask.index_put_(tuple(globs.t()), value)
    return global_attention_mask
```
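
A minimal usage sketch of `get_global_attention`, assuming the model scores the relation between a pair of mentions, each wrapped in `<m> ... </m>`; the two example sentences and the `</s></s>` pair separator are illustrative assumptions, not taken from the card:

```python
# Hypothetical example mentions; the "</s></s>" separator between the
# pair is an assumption for illustration.
m1 = "We study <m> coreference resolution </m> in scientific papers."
m2 = "The model links <m> entity coreference </m> across documents."

inputs = tokenizer(m1 + " </s></s> " + m2, return_tensors='pt')
global_attention_mask = get_global_attention(inputs['input_ids'])

with torch.no_grad():
    output = model(inputs['input_ids'],
                   attention_mask=inputs['attention_mask'],
                   global_attention_mask=global_attention_mask)

probs = torch.softmax(output.logits, dim=-1)  # one probability per relation class
```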