"""Streamlit demo: mask one token of a sentence and sample Flax ALBERT predictions."""
import time

import numpy as np
import pandas as pd
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns
import jax
import jax.numpy as jnp
from transformers import AlbertTokenizer

from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM

# NOTE(review): pandas, time, matplotlib, seaborn and jnp appear unused in
# this script; they are kept in case other tooling relies on them.

MODEL_NAME = 'albert-xxlarge-v2'


@st.cache(show_spinner=True, allow_output_mutation=True)
def load_model():
    """Load the ALBERT tokenizer and the custom Flax masked-LM model.

    The Flax weights are converted from the PyTorch checkpoint
    (``from_pt=True``). The function is wrapped in ``st.cache`` so that
    Streamlit reruns reuse the loaded objects instead of reloading them.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    # NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases in
    # favour of ``st.cache_resource``; switch once the pinned version allows.
    tokenizer = AlbertTokenizer.from_pretrained(MODEL_NAME)
    model = CustomFlaxAlbertForMaskedLM.from_pretrained(MODEL_NAME, from_pt=True)
    return tokenizer, model


def _sample_tokens(logprobs):
    """Draw one token id per row of a 2-D array of log-probabilities.

    Args:
        logprobs: array-like of shape ``(num_positions, vocab_size)`` holding
            log-softmax scores, one row per sequence position.

    Returns:
        list: one token id per row, sampled with the global ``np.random`` state.
    """
    # Exponentiate once, in float64, and renormalise each row.
    # exp(log_softmax) only sums to 1 up to rounding, and np.random.choice
    # raises if ``p`` does not sum to 1 within its tolerance.
    probs = np.exp(np.asarray(logprobs, dtype=np.float64))
    probs /= probs.sum(axis=-1, keepdims=True)
    # choice(n, p=...) samples from np.arange(n), i.e. a vocabulary index.
    return [np.random.choice(row.shape[0], p=row) for row in probs]


def main():
    """Render the demo page.

    Masks one position of a fixed sentence, runs the model, and writes the
    log-prob shape and one sampled token per position to the page.
    """
    # Config
    # NOTE(review): the two markup strings below are empty, so these layout
    # values are currently unused; the original <style> bodies may have been
    # lost and should be confirmed against version history.
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2

    define_margins = f""" """
    hide_table_row_index = """ """
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)

    tokenizer, model = load_model()
    # Same id the original obtained by tokenizing '[MASK]' and dropping the
    # surrounding special tokens.
    mask_id = tokenizer.mask_token_id

    input_ids = tokenizer('This is a sample sentence.', return_tensors='np').input_ids
    # Index 4 counts the leading special token at 0; presumably this is the
    # "sample" piece -- verify if the sentence or tokenizer changes.
    masked_position = 4
    input_ids[0][masked_position] = mask_id

    outputs = model(input_ids)
    logprobs = jax.nn.log_softmax(outputs.logits, axis=-1)
    st.write(logprobs.shape)

    # Sample a token for every position of the first (only) sequence.
    preds = _sample_tokens(logprobs[0])
    st.write([tokenizer.decode([token]) for token in preds])


if __name__ == '__main__':
    main()