import numpy as np
import pandas as pd
import time
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
from transformers import AlbertTokenizer
from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
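# NOTE: CustomFlaxAlbertForMaskedLM is assumed to be defined in a local
# custom_modeling_albert_flax.py and to mirror the interface of
# transformers' FlaxAlbertForMaskedLM (NumPy/JAX inputs, JAX logits);
# jax and flax must be installed alongside torch for this script to run.
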
if __name__ == '__main__':
    # Config
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2
    # CSS overrides built from the config above. The exact selector varies
    # across Streamlit versions; the one below is the commonly used form for
    # widening the main block container.
    define_margins = f"""
        <style>
            .appview-container .main .block-container{{
                max-width: {max_width}px;
                padding-top: {padding_top}rem;
                padding-right: {padding_right}rem;
                padding-bottom: {padding_bottom}rem;
                padding-left: {padding_left}rem;
            }}
        </style>
    """
    # Standard snippet for hiding the index column of st.table
    hide_table_row_index = """
        <style>
            thead tr th:first-child {display:none}
            tbody th {display:none}
        </style>
    """
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)
    # Load the Flax masked-LM (weights converted from the PyTorch checkpoint
    # via from_pt=True) and the matching tokenizer.
    model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2', from_pt=True)
    tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
    # Tokenize '[MASK]' and strip the [CLS]/[SEP] specials to get the mask id
    mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
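    # Sanity check (illustrative addition): this should equal the
    # tokenizer's registered mask token id.
    assert mask_id == tokenizer.mask_token_id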
    # tokenizer(...) returns a BatchEncoding, so take .input_ids explicitly;
    # NumPy tensors are used here because the model is Flax.
    input_ids = tokenizer('This is a sample sentence.', return_tensors='np').input_ids
    # Mask out position 4 ('sample' under ALBERT's tokenization)
    input_ids[0][4] = mask_id
    # The model is Flax, so torch.no_grad() is unnecessary; run the forward
    # pass and convert the JAX logits to a torch tensor for sampling below.
    outputs = model(input_ids)
    logits = torch.from_numpy(np.asarray(outputs.logits))
    logprobs = F.log_softmax(logits, dim=-1)  # (batch, seq_len, vocab_size)
    st.write(logprobs.shape)
    # Sample one token per position from the predicted distribution
    preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1).item()
             for probs in logprobs[0]]
    st.write([tokenizer.decode([token]) for token in preds])
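    # Illustrative addition (not in the original demo): a deterministic
    # readout at the masked position, i.e. the top-5 tokens by log-probability.
    topk = torch.topk(logprobs[0, 4], k=5)
    st.write([(tokenizer.decode([idx.item()]), round(val.item(), 3))
              for val, idx in zip(topk.values, topk.indices)])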