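# Streamlit demo: masked-LM predictions from ALBERT (albert-xxlarge-v2) on a
# pair of input sentences, run through a skeleton ALBERT implementation that
# accepts an `interventions` argument.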
import numpy as np
import pandas as pd
import time
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns

import jax
import jax.numpy as jnp

import torch
import torch.nn.functional as F

from transformers import AlbertTokenizer, AlbertForMaskedLM

#from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
from skeleton_modeling_albert import SkeletonAlbertForMaskedLM

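# Cache the tokenizer and model across Streamlit reruns (st.cache is the legacy
# caching API; allow_output_mutation disables the mutation check on the cached objects)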
@st.cache(show_spinner=True,allow_output_mutation=True)
def load_model():
    tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
    #model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
    model = AlbertForMaskedLM.from_pretrained('albert-xxlarge-v2')
    return tokenizer,model

def clear_data():
    # Iterate over a snapshot of the keys: deleting entries while iterating
    # st.session_state directly raises a RuntimeError
    for key in list(st.session_state.keys()):
        del st.session_state[key]

if __name__=='__main__':

    # Config
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2

    define_margins = f"""
    <style>
        .appview-container .main .block-container{{
            max-width: {max_width}px;
            padding-top: {padding_top}rem;
            padding-right: {padding_right}rem;
            padding-left: {padding_left}rem;
            padding-bottom: {padding_bottom}rem;
        }}
    </style>
    """
    hide_table_row_index = """
                <style>
                tbody th {display:none}
                .blank {display:none}
                </style>
                """
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)

    tokenizer,model = load_model()
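    # Id of the [MASK] token ([1:-1] strips the [CLS]/[SEP] specials added by the tokenizer)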
    mask_id = tokenizer('[MASK]').input_ids[1:-1][0]

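    # Sidebar inputs: two sentences to compare; editing either one clears any cached session state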
    sent_1 = st.sidebar.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.',on_change=clear_data)
    sent_2 = st.sidebar.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.',on_change=clear_data)
    input_ids_1 = tokenizer(sent_1).input_ids
    input_ids_2 = tokenizer(sent_2).input_ids
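    # Stack both sentences into a single batch (assumes they tokenize to the same length)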
    input_ids = torch.tensor([input_ids_1,input_ids_2])

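    # Forward pass through the skeleton ALBERT; the interventions dict appears to be
    # keyed by layer index, but its exact format is defined in skeleton_modeling_albert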
    outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions = {0:{'lay':[(8,1,[0,1])]}})
    logprobs = F.log_softmax(outputs.logits, dim = -1)
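    # Sample one token per position from the predicted distributions for sentence 1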
    preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0]]
    st.write([tokenizer.decode([token]) for token in preds])