import numpy as np
import pandas as pd
import time
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns
import jax
import jax.numpy as jnp
from transformers import AlbertTokenizer
from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
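# Streamlit demo: load ALBERT-xxlarge-v2 as a Flax masked-LM, mask one token of a
# sample sentence, and sample replacement tokens from the model's predictions.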
# Cache the tokenizer and model so they are only loaded once per session.
@st.cache(show_spinner=True, allow_output_mutation=True)
def load_model():
    tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
    # Load the PyTorch checkpoint and convert it to Flax weights.
    model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2', from_pt=True)
    return tokenizer, model
if __name__ == '__main__':
# Config
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2
    # The CSS bodies below were empty in the source; this is the usual Streamlit
    # layout override built from the variables above (class names vary by Streamlit version).
    define_margins = f"""
    <style>
        .appview-container .main .block-container {{
            max-width: {max_width}px;
            padding: {padding_top}rem {padding_right}rem {padding_bottom}rem {padding_left}rem;
        }}
    </style>
    """
    # Hide the index column of Streamlit tables.
    hide_table_row_index = """
    <style>
        thead tr th:first-child {display: none}
        tbody th {display: none}
    </style>
    """
st.markdown(define_margins, unsafe_allow_html=True)
st.markdown(hide_table_row_index, unsafe_allow_html=True)
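    # Load the cached tokenizer and Flax model (the first run downloads the weights).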
    tokenizer, model = load_model()
    # Build a sample input and replace one token with [MASK].
    mask_id = tokenizer.mask_token_id  # same id as tokenizer('[MASK]').input_ids[1:-1][0]
    input_ids = tokenizer('This is a sample sentence.', return_tensors='np').input_ids
    input_ids[0][4] = mask_id  # mask the fifth token ("sample")
    # Run the masked-LM and turn the logits into log-probabilities.
    outputs = model(input_ids)
    logprobs = jax.nn.log_softmax(outputs.logits, axis=-1)
    st.write(logprobs.shape)  # (batch_size, seq_len, vocab_size)
    # Sample one token per position from the predicted distribution.
    probs = np.asarray(jnp.exp(logprobs[0]))
    probs = probs / probs.sum(axis=-1, keepdims=True)  # renormalize against rounding error
    preds = [np.random.choice(len(p), p=p) for p in probs]
st.write([tokenizer.decode([token]) for token in preds])
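    # A deterministic alternative to the sampling above (added illustration, not part
    # of the original flow): take the argmax token at each position instead.
    greedy = jnp.argmax(logprobs[0], axis=-1)
    st.write([tokenizer.decode([int(token)]) for token in greedy])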