taka-yamakoshi commited on
Commit
2b49fe2
1 Parent(s): ced96a8

update app

Browse files
Files changed (1) hide show
  1. app.py +18 -4
app.py CHANGED
@@ -1,13 +1,16 @@
 
1
  import pandas as pd
 
2
  import streamlit as st
3
- import numpy as np
4
  import matplotlib.pyplot as plt
5
  import seaborn as sns
 
6
  import torch
7
  import torch.nn.functional as F
 
8
  from transformers import AlbertTokenizer
9
- import time
10
 
 
11
 
12
  if __name__=='__main__':
13
 
@@ -38,5 +41,16 @@ if __name__=='__main__':
38
  st.markdown(define_margins, unsafe_allow_html=True)
39
  st.markdown(hide_table_row_index, unsafe_allow_html=True)
40
 
41
- from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
42
- model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-base-v2')
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
  import pandas as pd
3
+ import time
4
  import streamlit as st
 
5
  import matplotlib.pyplot as plt
6
  import seaborn as sns
7
+
8
  import torch
9
  import torch.nn.functional as F
10
+
11
  from transformers import AlbertTokenizer
 
12
 
13
+ from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
14
 
15
  if __name__=='__main__':
16
 
 
41
  st.markdown(define_margins, unsafe_allow_html=True)
42
  st.markdown(hide_table_row_index, unsafe_allow_html=True)
43
 
44
+ model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
45
+ tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
46
+ mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
47
+
48
+ input_ids = tokenizer('This is a sample sentence.',return_tensors='pt')
49
+ input_ids[0][4] = mask_id
50
+
51
+ with torch.no_grad():
52
+ outputs = model(input_ids)
53
+ logprobs = F.log_softmax(outputs.logits, dim = -1)
54
+ st.write(logprobs.shape)
55
+ preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1).item() for probs in logprobs[0]]
56
+ st.write([tokenizer.decode([token]) for token in preds])