import numpy as np
import pandas as pd
import time
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
from transformers import AlbertTokenizer
from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
if __name__ == '__main__':
    # Page layout config: widen the main container and trim its padding,
    # and hide the index column of Streamlit tables, via injected CSS.
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2
    define_margins = f"""
    <style>
        .appview-container .main .block-container{{
            max-width: {max_width}px;
            padding-top: {padding_top}rem;
            padding-right: {padding_right}rem;
            padding-left: {padding_left}rem;
            padding-bottom: {padding_bottom}rem;
        }}
    </style>
    """
    hide_table_row_index = """
    <style>
        tbody th {display:none}
        .blank {display:none}
    </style>
    """
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)

    # Load the Flax ALBERT masked-LM (converted from the PyTorch checkpoint)
    # together with its tokenizer.
    model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2', from_pt=True)
    tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')

    # Tokenize a sample sentence and mask the token at position 4 ('sample').
    mask_id = tokenizer.mask_token_id  # same id as tokenizer('[MASK]').input_ids[1:-1][0]
    input_ids = tokenizer('This is a sample sentence.', return_tensors='pt').input_ids
    input_ids[0][4] = mask_id
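
    # Optional sanity check (a sketch added here, not in the original app):
    # decode the masked input to confirm that position 4 now reads [MASK].
    st.write(tokenizer.decode(input_ids[0]))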

    # Run the Flax model; torch.no_grad() is unnecessary here since JAX does not
    # track gradients on a plain forward call. Convert the returned logits to a
    # torch tensor before applying torch's log_softmax.
    outputs = model(input_ids)
    logprobs = F.log_softmax(torch.from_numpy(np.asarray(outputs.logits)), dim=-1)
    st.write(logprobs.shape)

    # Sample one token per position from the predicted distribution and decode it.
    preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1).item()
             for probs in logprobs[0]]
    st.write([tokenizer.decode([token]) for token in preds])
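
    # Optional alternative to sampling (a sketch, not part of the original app):
    # show the top-5 candidates at the masked position. topk on log-probabilities
    # gives the same ranking as topk on probabilities.
    top_logprobs, top_ids = torch.topk(logprobs[0][4], k=5)
    st.write([(tokenizer.decode([i.item()]), lp.item()) for i, lp in zip(top_ids, top_logprobs)])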