import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Load the SciCo cross-document coreference model and its tokenizer from the Hub.
tokenizer = AutoTokenizer.from_pretrained('allenai/longformer-scico')
model = AutoModelForSequenceClassification.from_pretrained('allenai/longformer-scico')

# Token ids of the mention-boundary markers "<m>" / "</m>"; positions of these
# tokens receive Longformer global attention (see get_global_attention below).
start_token = tokenizer.convert_tokens_to_ids("<m>")
end_token = tokenizer.convert_tokens_to_ids("</m>")

def get_global_attention(input_ids, start_id=None, end_id=None):
    """Build a Longformer global-attention mask for a batch of token ids.

    Global attention is placed on the CLS position (column 0) and on every
    occurrence of the mention-boundary marker tokens.

    Args:
        input_ids: integer tensor of shape (batch, seq_len) with token ids.
        start_id: token id of the "<m>" marker; defaults to the module-level
            ``start_token`` resolved from the tokenizer.
        end_id: token id of the "</m>" marker; defaults to the module-level
            ``end_token``.

    Returns:
        Float tensor of the same shape as ``input_ids`` with 1.0 at
        global-attention positions and 0.0 elsewhere.
    """
    if start_id is None:
        start_id = start_token
    if end_id is None:
        end_id = end_token
    global_attention_mask = torch.zeros(input_ids.shape)
    global_attention_mask[:, 0] = 1  # global attention on the CLS token
    # (row, col) coordinates of every <m> and </m> occurrence in the batch.
    starts = torch.nonzero(input_ids == start_id)
    ends = torch.nonzero(input_ids == end_id)
    marker_positions = torch.cat((starts, ends))
    values = torch.ones(marker_positions.shape[0])
    # Scatter 1.0 at each marker coordinate in one indexed write.
    global_attention_mask.index_put_(tuple(marker_positions.t()), values)
    return global_attention_mask
  
def inference(m1, m2):
    """Score the relation between two mention-marked sentences.

    The two inputs are joined with the Longformer separator pair
    " </s></s> ", run through the SciCo model with global attention on the
    CLS and <m>/</m> marker tokens, and softmaxed into class probabilities.

    Args:
        m1: first text, with the target mention wrapped in <m> ... </m>.
        m2: second text, marked the same way.

    Returns:
        dict mapping each relation label ('not related', 'coref', 'parent',
        'child') to its probability.
    """
    inputs = m1 + " </s></s> " + m2

    tokens = tokenizer(inputs, return_tensors='pt')
    global_attention_mask = get_global_attention(tokens['input_ids'])

    # Inference only — no gradients needed.
    with torch.no_grad():
        output = model(tokens['input_ids'], tokens['attention_mask'], global_attention_mask)

    # Logits -> probabilities over the four SciCo relation classes.
    scores = torch.softmax(output.logits, dim=-1)[0].tolist()
    labels = ('not related', 'coref', 'parent', 'child')
    return dict(zip(labels, scores))

# UI copy for the Gradio demo page.
title = "Longformer-scico"
description = "Gradio demo for Longformer-scico. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://openreview.net/forum?id=OFLbgUP04nC'>SciCo: Hierarchical Cross-Document Coreference for Scientific Concepts</a> | <a href='https://github.com/ariecattan/SciCo'>Github Repo</a></p>"

# One example pair; each text wraps its target mention in <m> ... </m>,
# matching what inference() expects.
examples = [["In this paper we present the results of an experiment in <m> automatic concept and definition extraction </m> from written sources of law using relatively simple natural methods.","This task is important since many natural language processing (NLP) problems, such as <m> information extraction </m>, summarization and dialogue."]]
# NOTE(review): gr.inputs.*, gr.outputs.* and the enable_queue kwarg are the
# legacy Gradio 2.x API, removed in Gradio 3+/4+ — confirm the pinned gradio
# version before touching this (modern equivalent: gr.Textbox / gr.Label and
# .queue()).
gr.Interface(
    inference, 
    [gr.inputs.Textbox(label="Input1"),gr.inputs.Textbox(label="Input2")], 
    gr.outputs.Label(label="Output"),
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    examples=examples
    ).launch(debug=True)