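# Streamlit demo: mask one token of a sample sentence with ALBERT (xxlarge-v2) and
# sample a prediction for every position from the masked-LM output distribution.
# Relies on the local module custom_modeling_albert_flax for CustomFlaxAlbertForMaskedLM.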
import numpy as np
import pandas as pd
import time
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn.functional as F

from transformers import AlbertTokenizer

from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM

if __name__=='__main__':

    # Config
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2

    # Inline CSS that widens Streamlit's main block container and applies the padding above
    define_margins = f"""
    <style>
        .appview-container .main .block-container{{
            max-width: {max_width}px;
            padding-top: {padding_top}rem;
            padding-right: {padding_right}rem;
            padding-left: {padding_left}rem;
            padding-bottom: {padding_bottom}rem;
        }}
    </style>
    """
    # Inline CSS that hides the row-index column of Streamlit tables
    hide_table_row_index = """
    <style>
        tbody th {display:none}
        .blank {display:none}
    </style>
    """
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)

    # Load the Flax ALBERT masked-LM model (weights converted from the PyTorch checkpoint)
    model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2', from_pt=True)
    tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
    # Id of the [MASK] token; slicing [1:-1] drops the [CLS]/[SEP] special tokens
    mask_id = tokenizer('[MASK]').input_ids[1:-1][0]

    # Tokenize a sample sentence and replace the token at position 4 with [MASK]
    input_ids = tokenizer('This is a sample sentence.', return_tensors='pt').input_ids
    input_ids[0][4] = mask_id

    # The model is a Flax module, so torch.no_grad() is not needed for the forward pass.
    # Assuming it returns JAX arrays like the stock FlaxAlbertForMaskedLM, convert the
    # logits to a torch tensor for the sampling step below.
    outputs = model(np.asarray(input_ids))
    logprobs = F.log_softmax(torch.tensor(np.asarray(outputs.logits)), dim=-1)
    st.write(logprobs.shape)

    # Sample one token id per position from the predicted distribution and decode it
    preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1).item()
             for probs in logprobs[0]]
    st.write([tokenizer.decode([token]) for token in preds])