import numpy as np
import pandas as pd
import time
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
from transformers import AlbertTokenizer
from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM

if __name__ == '__main__':
    # Page layout config: widen the main container and trim its padding.
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2
    define_margins = f"""
    <style>
        .appview-container .main .block-container {{
            max-width: {max_width}px;
            padding-top: {padding_top}rem;
            padding-right: {padding_right}rem;
            padding-bottom: {padding_bottom}rem;
            padding-left: {padding_left}rem;
        }}
    </style>
    """
    # Hide the index column that st.table renders by default.
    hide_table_row_index = """
    <style>
        thead tr th:first-child {display: none}
        tbody th {display: none}
    </style>
    """
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)

    # Load the custom Flax ALBERT masked-LM (converted from the PyTorch
    # checkpoint) and the matching tokenizer.
    model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2', from_pt=True)
    tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
    mask_id = tokenizer.mask_token_id

    # Tokenize a sample sentence and mask the token at position 4 ('sample').
    input_ids = tokenizer('This is a sample sentence.', return_tensors='pt').input_ids
    input_ids[0][4] = mask_id

    # The custom wrapper is assumed here to accept torch tensors and return
    # torch logits (it was loaded with from_pt=True); if it returns jax
    # arrays instead, convert with torch.from_numpy(np.asarray(...)).
    with torch.no_grad():
        outputs = model(input_ids)
    logprobs = F.log_softmax(outputs.logits, dim=-1)
    st.write(logprobs.shape)

    # Sample one token per position from that position's predicted distribution.
    preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1).item()
             for probs in logprobs[0]]
    st.write([tokenizer.decode([token]) for token in preds])
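
# --- Optional sketch: deterministic top-k readout at the masked position ---
# The main block above samples every position with torch.multinomial; a
# complementary view is to list the k most likely fillers for the [MASK]
# slot itself. This helper is an illustrative addition, not part of the
# original app; it assumes the `logprobs` and `tokenizer` produced in the
# main block, e.g. show_topk_predictions(logprobs, tokenizer) after
# computing logprobs.
def show_topk_predictions(logprobs, tokenizer, mask_position=4, k=5):
    # logprobs: (batch, seq_len, vocab_size) log-probabilities from the model.
    topk = torch.topk(logprobs[0, mask_position], k=k)
    rows = [{'token': tokenizer.decode([tok.item()]),
             'logprob': round(val.item(), 3)}
            for val, tok in zip(topk.values, topk.indices)]
    # Render as a table; the row index is hidden by the CSS injected above.
    st.table(pd.DataFrame(rows))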