File size: 2,078 Bytes
d9c22e0
 
 
 
 
ca25754
 
d9c22e0
 
 
ca25754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import streamlit as st
import torch
from transformers import BertTokenizer, BertModel
from huggingface_hub import hf_hub_url, cached_download

def get_cls_layer(repo_id="furrutiav/beto_coherence", revision=None):
    """Download and load the ``cls_layer.torch`` file of a Hub repository.

    Args:
        repo_id: Hugging Face Hub repository that hosts ``cls_layer.torch``.
        revision: Optional branch, tag or commit hash to fetch the file from.
            ``None`` (the default) leaves the choice to ``hf_hub_url``, which
            is what this function did before the parameter existed. Passing a
            commit hash pins the file to a fixed version.

    Returns:
        The object deserialized by ``torch.load``, mapped onto the CPU.
    """
    file_url = hf_hub_url(repo_id, filename="cls_layer.torch", revision=revision)
    # NOTE(review): recent huggingface_hub releases deprecate cached_download
    # in favour of hf_hub_download -- confirm against the installed version.
    local_path = cached_download(file_url)
    # SECURITY: torch.load unpickles the downloaded file, which can execute
    # arbitrary code. Only point repo_id/revision at repositories you trust.
    return torch.load(local_path, map_location=torch.device("cpu"))
 
# Repository and pinned commit shared by the encoder and its tokenizer.
_BETO_REPO_ID = "furrutiav/beto_coherence"
_BETO_REVISION = "df96f50cfb1e3f7923912a25b1c3a865116fae4a"

# Object loaded from the repository's ``cls_layer.torch`` file.
cls_layer = get_cls_layer()

beto_model = BertModel.from_pretrained(
    _BETO_REPO_ID,
    revision=_BETO_REVISION,
)

# Casing is preserved: lower-casing is explicitly switched off.
beto_tokenizer = BertTokenizer.from_pretrained(
    _BETO_REPO_ID,
    revision=_BETO_REVISION,
    do_lower_case=False,
)

# Put the encoder in evaluation mode; eval() returns the module, kept as ``e``.
e = beto_model.eval()

def _normalize_text(text):
    """Collapse all whitespace runs in ``str(text)``; empty results become "nan"."""
    # str.split() with no argument already splits on newlines, so no separate
    # newline replacement is needed.
    flat = " ".join(str(text).split())
    return flat or "nan"


def _pad_or_truncate(tokens, maxlen):
    """Pad *tokens* with '[PAD]' up to ``maxlen``, or cut it and close with '[SEP]'."""
    if len(tokens) < maxlen:
        return tokens + ['[PAD]'] * (maxlen - len(tokens))
    # Keep the first maxlen - 1 tokens and make sure the segment still ends
    # with a separator.
    return tokens[:maxlen - 1] + ['[SEP]']


def preproccesing(Q, A, maxlen=60):
    """Turn a pair of texts into token ids and an attention mask.

    Each text is converted with ``str``, whitespace-normalized (an empty
    result is replaced by the literal ``"nan"``) and tokenized with
    ``beto_tokenizer``. The two segments are padded/truncated to ``maxlen``
    tokens *independently* and then concatenated::

        [CLS] Q... [SEP] [PAD]...   A... [SEP] [PAD]...

    Args:
        Q: First text (presumably the question -- confirm with callers).
        A: Second text (presumably the answer -- confirm with callers).
        maxlen: Number of tokens per segment; for ``maxlen >= 1`` the result
            holds ``2 * maxlen`` tokens.

    Returns:
        A tuple ``(tokens_ids_tensor, attn_mask)`` of 1-D tensors of the same
        length: the token ids, and a mask holding 0 at '[PAD]' positions and
        1 everywhere else.
    """
    tokens1 = ['[CLS]'] + beto_tokenizer.tokenize(_normalize_text(Q)) + ['[SEP]']
    tokens2 = beto_tokenizer.tokenize(_normalize_text(A)) + ['[SEP]']

    # NOTE(review): the padding of the first segment sits between the two
    # texts. The layout is kept as is; the fine-tuned weights were presumably
    # trained with it -- confirm before changing.
    tokens = _pad_or_truncate(tokens1, maxlen) + _pad_or_truncate(tokens2, maxlen)
    tokens_ids_tensor = torch.tensor(beto_tokenizer.convert_tokens_to_ids(tokens))

    # Look the padding id up in the vocabulary instead of hard-coding it, so
    # the mask always matches the '[PAD]' tokens inserted above.
    pad_id = beto_tokenizer.convert_tokens_to_ids('[PAD]')
    attn_mask = (tokens_ids_tensor != pad_id).long()
    return tokens_ids_tensor, attn_mask

def C1Classifier(Q, A, is_probs=True):
    """Score a pair of texts with the BETO encoder and the classification head.

    The pair is encoded by ``preproccesing``, run through ``beto_model`` as a
    batch of one, and the hidden state at the first position (the '[CLS]'
    token) is passed to ``cls_layer``. A sigmoid is applied to the output.

    Args:
        Q: First text (presumably the question -- confirm with callers).
        A: Second text (presumably the answer -- confirm with callers).
        is_probs: If true (default), return the sigmoid scores; otherwise
            return the index of the highest score.

    Returns:
        When ``is_probs`` is true, a NumPy array with the sigmoid scores of
        the single input pair. Otherwise a NumPy integer holding the argmax
        over those scores.
    """
    tokens_ids_tensor, attn_mask = preproccesing(Q, A)

    # Inference only: disabling autograd avoids building a graph that is never
    # used, and lets the results be converted to NumPy without detach().
    with torch.no_grad():
        cont_reps = beto_model(
            tokens_ids_tensor.unsqueeze(0),
            attention_mask=attn_mask.unsqueeze(0),
        )
        cls_rep = cont_reps.last_hidden_state[:, 0]
        probs = torch.sigmoid(cls_layer(cls_rep))

    if is_probs:
        return probs.numpy()[0]
    # Hard label: position of the largest score along the class dimension.
    return probs.argmax(1).numpy()[0]