taka-yamakoshi commited on
Commit
a0471c4
1 Parent(s): 1f8519e

model type

Browse files
Files changed (1) hide show
  1. app.py +18 -9
app.py CHANGED
@@ -47,10 +47,10 @@ def load_css(file_name):
47
  st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
48
 
49
  @st.cache(show_spinner=True,allow_output_mutation=True)
50
- def load_model():
51
- tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
52
  #model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
53
- model = AlbertForMaskedLM.from_pretrained('albert-xxlarge-v2')
54
  return tokenizer,model
55
 
56
  def clear_data():
@@ -167,14 +167,23 @@ def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_i
167
  if __name__=='__main__':
168
  wide_setup()
169
  load_css('style.css')
170
- tokenizer,model = load_model()
171
- num_layers, num_heads = model.config.num_hidden_layers, model.config.num_attention_heads
172
- mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
173
-
174
- main_area = st.empty()
175
 
176
  if 'page_status' not in st.session_state:
177
- st.session_state['page_status'] = 'type_in'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
  if st.session_state['page_status']=='type_in':
180
  show_instruction('1. Type in the sentences and click "Tokenize"',fontsize=16)
 
47
  st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
48
 
49
  @st.cache(show_spinner=True,allow_output_mutation=True)
50
+ def load_model(model_name):
51
+ tokenizer = AlbertTokenizer.from_pretrained(model_name)
52
  #model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
53
+ model = AlbertForMaskedLM.from_pretrained(model_name)
54
  return tokenizer,model
55
 
56
  def clear_data():
 
167
  if __name__=='__main__':
168
  wide_setup()
169
  load_css('style.css')
 
 
 
 
 
170
 
171
  if 'page_status' not in st.session_state:
172
+ st.session_state['page_status'] = 'model_selection'
173
+
174
+ if st.session_state['page_status']=='model_selection':
175
+ model_name = st.selectbox('Please select the model from below.',
176
+ ('bert-base-uncased','bert-large-cased',
177
+ 'roberta-base','roberta-large',
178
+ 'albert-base-v2','albert-large-v2','albert-xlarge-v2','albert-xxlarge-v2'),index=3)
179
+ st.sesstion_state['model_name'] = model_name
180
+ if st.button('Confirm',key='model_name'):
181
+ st.session_state['page_status'] = 'type_in'
182
+ st.experimental_rerun()
183
+
184
+ tokenizer,model = load_model(st.session_state['model_name'])
185
+ num_layers, num_heads = model.config.num_hidden_layers, model.config.num_attention_heads
186
+ mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
187
 
188
  if st.session_state['page_status']=='type_in':
189
  show_instruction('1. Type in the sentences and click "Tokenize"',fontsize=16)