taka-yamakoshi
commited on
Commit
•
2b49fe2
1
Parent(s):
ced96a8
update app
Browse files
app.py
CHANGED
@@ -1,13 +1,16 @@
|
|
|
|
1 |
import pandas as pd
|
|
|
2 |
import streamlit as st
|
3 |
-
import numpy as np
|
4 |
import matplotlib.pyplot as plt
|
5 |
import seaborn as sns
|
|
|
6 |
import torch
|
7 |
import torch.nn.functional as F
|
|
|
8 |
from transformers import AlbertTokenizer
|
9 |
-
import time
|
10 |
|
|
|
11 |
|
12 |
if __name__=='__main__':
|
13 |
|
@@ -38,5 +41,16 @@ if __name__=='__main__':
|
|
38 |
st.markdown(define_margins, unsafe_allow_html=True)
|
39 |
st.markdown(hide_table_row_index, unsafe_allow_html=True)
|
40 |
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
import pandas as pd
|
3 |
+
import time
|
4 |
import streamlit as st
|
|
|
5 |
import matplotlib.pyplot as plt
|
6 |
import seaborn as sns
|
7 |
+
|
8 |
import torch
|
9 |
import torch.nn.functional as F
|
10 |
+
|
11 |
from transformers import AlbertTokenizer
|
|
|
12 |
|
13 |
+
from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
|
14 |
|
15 |
if __name__=='__main__':
|
16 |
|
|
|
41 |
st.markdown(define_margins, unsafe_allow_html=True)
|
42 |
st.markdown(hide_table_row_index, unsafe_allow_html=True)
|
43 |
|
44 |
+
model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
|
45 |
+
tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
|
46 |
+
mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
|
47 |
+
|
48 |
+
input_ids = tokenizer('This is a sample sentence.',return_tensors='pt')
|
49 |
+
input_ids[0][4] = mask_id
|
50 |
+
|
51 |
+
with torch.no_grad():
|
52 |
+
outputs = model(input_ids)
|
53 |
+
logprobs = F.log_softmax(outputs.logits, dim = -1)
|
54 |
+
st.write(logprobs.shape)
|
55 |
+
preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1).item() for probs in logprobs[0]]
|
56 |
+
st.write([tokenizer.decode([token]) for token in preds])
|