"""Streamlit demo: probing BERT's priors with serial reproduction chains.

The app either loads a pre-generated chain of sentences from ``tsne_out.csv``
or runs a new Gibbs-style chain with a BERT masked language model (one token
is masked and re-sampled per step), projects the sentences to 2-D with t-SNE
and lets the user step through, or autoplay, the trajectory.
"""
import html
import io
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from transformers import BertTokenizer, BertForMaskedLM


@st.cache(show_spinner=True, allow_output_mutation=True)
def load_sentence_model():
    """Load (and cache) the sentence-embedding model used for the t-SNE map."""
    return SentenceTransformer('paraphrase-distilroberta-base-v1')


@st.cache(show_spinner=True, allow_output_mutation=True)
def load_model(model_name):
    """Load (and cache) the tokenizer and masked LM for ``model_name``.

    Args:
        model_name: Hugging Face model id; only names starting with ``'bert'``
            are supported.

    Returns:
        ``(tokenizer, model)`` with the model switched to eval mode.

    Raises:
        ValueError: if ``model_name`` is not a BERT model. Previously this
            case fell through and failed with an unbound-variable error.
    """
    if not model_name.startswith('bert'):
        raise ValueError(f'Unsupported model {model_name!r}: only BERT models are supported')
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertForMaskedLM.from_pretrained(model_name)
    model.eval()
    return tokenizer, model


@st.cache(show_spinner=False)
def load_data(sentence_num):
    """Return the pre-generated chain for one example sentence.

    Reads ``tsne_out.csv``, keeps the rows whose ``sentence_num`` matches and
    whose ``iter_num`` is below 1000, and renumbers the index from 0 so that
    the index can be used directly as a step id.
    """
    df = pd.read_csv('tsne_out.csv')
    df = df.loc[lambda d: (d['sentence_num'] == sentence_num) & (d['iter_num'] < 1000)]
    return df.reset_index()


def mask_prob(model, mask_id, sentences, position, temp=1):
    """Return log-probabilities over the vocabulary for one masked position.

    Args:
        model: masked language model; called on the token-id tensor and
            expected to return the logits as its first output.
        mask_id: token id written into ``position`` before the forward pass.
        sentences: tensor of token ids indexed as ``[batch, position]``.
        position: column to mask.
        temp: softmax temperature (logits are divided by it).

    Returns:
        ``log_softmax`` of the logits at ``position``, one row per sentence.
        ``sentences`` itself is not modified (a clone is masked).
    """
    masked_sentences = sentences.clone()
    masked_sentences[:, position] = mask_id
    with torch.no_grad():
        logits = model(masked_sentences)[0]
    return F.log_softmax(logits[:, position] / temp, dim=-1)


def sample_words(probs, pos, sentences, tokenizer=None):
    """Sample a replacement token for ``pos`` and list the top-10 candidates.

    Args:
        probs: log-probabilities as returned by ``mask_prob``.
        pos: column of ``sentences`` to overwrite with the sampled token.
        sentences: tensor of token ids indexed as ``[batch, position]``.
        tokenizer: tokenizer used to decode candidate ids. When omitted, the
            module-level ``tokenizer`` is used (the original behaviour).

    Returns:
        ``(new_sentences, df)`` where ``new_sentences`` is a copy of
        ``sentences`` with column ``pos`` replaced, and ``df`` has columns
        ``word`` and ``prob`` for the ten most probable tokens of the first
        sentence.

    Raises:
        NameError: if no tokenizer is passed and none exists at module level.
    """
    if tokenizer is None:
        tokenizer = globals().get('tokenizer')
        if tokenizer is None:
            raise NameError('sample_words needs a tokenizer: pass one or define a module-level `tokenizer`')

    # Exponentiate once; the original recomputed this for every candidate.
    word_probs = torch.exp(probs)
    top_ids = torch.argsort(probs[0], descending=True)[:10].tolist()
    candidates = [[tokenizer.decode([token_id]), word_probs[0, token_id].item()]
                  for token_id in top_ids]
    df = pd.DataFrame(data=candidates, columns=['word', 'prob'])

    chosen_words = torch.multinomial(word_probs, num_samples=1).squeeze(dim=-1)
    new_sentences = sentences.clone()
    new_sentences[:, pos] = chosen_words
    return new_sentences, df


def run_chains(tokenizer, model, mask_id, input_text, num_steps):
    """Run a sampling chain starting from ``input_text``.

    At every step the current sentence is recorded, one position between the
    first and last token is chosen uniformly at random, masked, and replaced
    by a token sampled from the model's distribution for that position.
    Progress is reported in the sidebar.

    Returns:
        DataFrame with columns ``step``, ``sentence`` (decoded, including the
        special tokens) and ``next_sample_loc`` (the position re-sampled
        *after* that row was recorded).

    Raises:
        ValueError: if ``input_text`` yields no tokens besides the two
            special tokens, so there is nothing to re-sample.
    """
    init_sent = tokenizer(input_text, return_tensors='pt')['input_ids']
    seq_len = init_sent.shape[1]
    if seq_len <= 2:
        raise ValueError('input_text must contain at least one token to resample')
    sentence = init_sent.clone()
    data_list = []

    st.sidebar.write('Generating samples...')
    st.sidebar.write('This takes ~1 min for 1000 steps with ~10 token sentences')
    chain_progress = st.sidebar.progress(0)
    for step_id in range(num_steps):
        chain_progress.progress((step_id + 1) / num_steps)
        # Skip the first and last token (the special tokens added by the tokenizer).
        pos = torch.randint(seq_len - 2, size=(1,)).item() + 1
        data_list.append([step_id, tokenizer.decode(sentence[0].tolist()), pos])
        probs = mask_prob(model, mask_id, sentence, pos)
        sentence, _ = sample_words(probs, pos, sentence, tokenizer=tokenizer)
    return pd.DataFrame(data=data_list, columns=['step', 'sentence', 'next_sample_loc'])


def _padded_limits(values):
    """Return ``[low, high]`` axis limits snapped outwards to a tenth of the range."""
    low, high = values.min(), values.max()
    unit = (high - low) / 10
    if unit == 0:
        # All points coincide; fall back to a unit grid instead of dividing by zero.
        unit = 1.0
    return [(low // unit - 1) * unit, (high // unit + 1) * unit]


def show_tsne_panel(df, step_id):
    """Draw the t-SNE trajectory up to and including ``step_id``.

    Args:
        df: DataFrame with ``x_tsne`` and ``y_tsne`` columns, one row per step.
        step_id: positional index of the current step (marked with a star).

    Returns:
        A new matplotlib figure. The caller owns it and should close it
        (``plt.close(fig)``) once rendered, otherwise figures accumulate.
    """
    x_tsne, y_tsne = df.x_tsne, df.y_tsne
    # Limits are computed from the whole chain so the view stays fixed across steps.
    xlims = _padded_limits(x_tsne)
    ylims = _padded_limits(y_tsne)
    color_list = sns.color_palette('flare', n_colors=int(len(df) * 1.2))

    visited_x, visited_y = x_tsne.iloc[:step_id + 1], y_tsne.iloc[:step_id + 1]
    fig = plt.figure(figsize=(5, 5), dpi=200)
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(visited_x, visited_y, linewidth=0.2, color='gray', zorder=1)
    ax.scatter(visited_x, visited_y, s=5, color=color_list[:step_id + 1], zorder=2)
    ax.scatter(x_tsne.iloc[step_id:step_id + 1], y_tsne.iloc[step_id:step_id + 1],
               s=50, marker='*', color='blue', zorder=3)
    ax.set_xlim(*xlims)
    ax.set_ylim(*ylims)
    ax.axis('off')
    return fig


def run_tsne(chain):
    """Embed every sentence of ``chain`` and project the chain to 2-D.

    Adds a ``cleaned_sentence`` column (the ``[CLS] `` prefix and `` [SEP]``
    suffix removed), encodes it with the sentence model, reduces with PCA and
    then t-SNE.

    Args:
        chain: DataFrame with a ``sentence`` column, as built by ``run_chains``.

    Returns:
        ``chain`` plus the columns ``cleaned_sentence``, ``x_tsne``, ``y_tsne``.

    Raises:
        ValueError: if the chain has fewer than two sentences.
    """
    st.sidebar.write('Running t-SNE...')
    st.sidebar.write('This takes ~1 min for 1000 steps with ~10 token sentences')
    cleaned = (chain.sentence
               .str.replace(r'\[CLS\] ', '', regex=True)
               .str.replace(r' \[SEP\]', '', regex=True))
    chain = chain.assign(cleaned_sentence=cleaned)
    if len(chain) < 2:
        raise ValueError('t-SNE needs at least two sentences')

    sentence_model = load_sentence_model()
    sentence_embeddings = sentence_model.encode(chain.cleaned_sentence.to_list(),
                                                show_progress_bar=False)
    # assumes encode() returns a 2-D (n_sentences, embedding_dim) array — TODO confirm
    n_samples, n_features = sentence_embeddings.shape

    # PCA cannot keep more components than samples/features, and t-SNE needs
    # perplexity < n_samples; the fixed 50 / default 30 broke short chains.
    big_pca = PCA(n_components=min(50, n_samples, n_features))
    tsne_kwargs = dict(n_components=2, perplexity=min(30, n_samples - 1))
    try:
        tsne = TSNE(max_iter=2000, **tsne_kwargs)
    except TypeError:
        # Older scikit-learn releases only know the parameter as `n_iter`.
        tsne = TSNE(n_iter=2000, **tsne_kwargs)

    tsne_vals = tsne.fit_transform(big_pca.fit_transform(sentence_embeddings))
    coords = pd.DataFrame(tsne_vals, columns=['x_tsne', 'y_tsne'], index=chain.index)
    return pd.concat([chain, coords], axis=1)


def autoplay():
    """Animate the chain from the current step to the end (about 4 frames/s).

    Updates ``st.session_state.step_id`` / ``prev_step_id`` as it goes and
    clears its placeholder when finished, so the caller can render the final
    state itself.
    """
    df = st.session_state.df
    frame = st.empty()
    for step_id in range(st.session_state.step_id, len(df)):
        # Advance the state *before* rendering so the sentence shown matches
        # the point drawn; previously the text lagged one step behind.
        st.session_state.prev_step_id = st.session_state.step_id
        st.session_state.step_id = step_id
        with frame.container():
            st.markdown(show_changed_site(), unsafe_allow_html=True)
            fig = show_tsne_panel(df, step_id)
            cols = st.columns([1, 2, 1])
            with cols[1]:
                st.pyplot(fig)
            plt.close(fig)  # one figure per frame would otherwise pile up
        time.sleep(.25)
    frame.empty()


def initialize_buttons():
    """Render the step buttons and the candidates checkbox, and apply clicks.

    A pressed button moves ``st.session_state.step_id`` by its label's amount,
    clamped to the chain; the previous value is kept in ``prev_step_id``.
    """
    step_rows = [
        {'+1': 1, '+10': 10, '+100': 100, '+500': 500},
        {'-1': -1, '-10': -10, '-100': -100, '-500': -500},
    ]
    pressed_increments = []
    buttons = st.sidebar.empty()
    with buttons.container():
        for row in step_rows:
            for col, (label, increment) in zip(st.columns([4, 5, 6, 6]), row.items()):
                # The label doubles as the widget key, as before.
                if col.button(label, key=label):
                    pressed_increments.append(increment)
        # NOTE(review): the file's indentation was lost; the checkbox is assumed
        # to live in the sidebar container with the buttons — confirm.
        show_candidates_checked = st.checkbox('Show candidates')

    # Increment if any of them have been pressed
    if pressed_increments:
        last_step = len(st.session_state.df) - 1
        st.session_state.prev_step_id = st.session_state.step_id
        new_step_id = int(st.session_state.step_id) + pressed_increments[0]
        st.session_state.step_id = min(last_step, max(0, new_step_id))

    if show_candidates_checked:
        st.write('Click any word to see each candidate with its probability')
        show_candidates()


def show_candidates():
    """Show the current sentence as one button per token.

    Clicking a token masks it and displays the ten most probable fillers.
    Relies on the module-level ``tokenizer``, ``model`` and ``mask_id`` that
    the script section defines.
    """
    if 'curr_table' in st.session_state:
        st.session_state.curr_table.empty()
    step_id = st.session_state.step_id
    # Read the chain from the session (as show_changed_site does) rather than
    # from a module global.
    sentence = st.session_state.df.cleaned_sentence.loc[step_id]
    input_sent = tokenizer(sentence, return_tensors='pt')['input_ids']
    decoded_sent = [tokenizer.decode([token]) for token in input_sent[0]]
    # Column widths roughly proportional to the token lengths.
    cols = st.columns([len(word) + 2 for word in decoded_sent])
    with cols[0]:
        st.write(decoded_sent[0])
    with cols[-1]:
        st.write(decoded_sent[-1])
    for word_id, (col, word) in enumerate(zip(cols[1:-1], decoded_sent[1:-1])):
        with col:
            if st.button(word, key=f'word_{word_id}'):
                # +1: word_id counts from the first token after the leading special token.
                probs = mask_prob(model, mask_id, input_sent, word_id + 1)
                _, candidates_df = sample_words(probs, word_id + 1, input_sent,
                                                tokenizer=tokenizer)
                st.session_state.curr_table = st.table(candidates_df)


def show_changed_site():
    """Return the HTML for the current sentence with the changed word highlighted.

    If the chain has a ``next_sample_loc`` column, the highlighted word is the
    position re-sampled on the way into the current step; otherwise every word
    that does not occur in the previously shown sentence is highlighted.
    """
    df = st.session_state.df
    step_id = st.session_state.step_id
    prev_step_id = st.session_state.prev_step_id
    curr_sent = df.cleaned_sentence.loc[step_id].split(' ')
    prev_sent = df.cleaned_sentence.loc[prev_step_id].split(' ')
    if 'next_sample_loc' in df:
        # Row step_id-1 holds the position sampled to produce this step; the
        # extra -1 accounts for the removed leading special token. Step 0 has
        # no predecessor (the old index -1 wrapped around to the last row).
        locs = [df.next_sample_loc.iloc[step_id - 1] - 1] if step_id > 0 else []
    else:
        locs = [i for i, word in enumerate(curr_sent) if word not in prev_sent]

    disp_style = '"font-family:san serif; color:Black; font-size: 20px"'
    # NOTE(review): the markup in these literals appeared to have been stripped
    # from the file (both highlight branches were identical and disp_style was
    # unused); it is reconstructed here — confirm the intended styling.
    prefix = f'<p style={disp_style}>Step {st.session_state.step_id}:&nbsp;&nbsp;'
    words = []
    for i, word in enumerate(curr_sent):
        # Sentences can come from user input and are rendered as raw HTML.
        safe_word = html.escape(word, quote=False)
        words.append(f'<span style="color:Red">{safe_word}</span>' if i in locs else safe_word)
    suffix = '</p>'
    return prefix + ' '.join(words) + suffix


def clear_df():
    """Forget the current chain and its step counters (widget on_change callback)."""
    # Also drop the counters: a stale step_id could point past a new, shorter chain.
    for key in ('df', 'step_id', 'prev_step_id'):
        if key in st.session_state:
            del st.session_state[key]


if __name__ == '__main__':
    # Config
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2

    # NOTE(review): both style blocks were blank in the file although the
    # values above exist only to fill them; the CSS below is a reconstruction
    # and the selectors should be checked against the deployed Streamlit version.
    define_margins = f"""
    <style>
        .appview-container .main .block-container,
        .reportview-container .main .block-container {{
            max-width: {max_width}px;
            padding-top: {padding_top}rem;
            padding-right: {padding_right}rem;
            padding-left: {padding_left}rem;
            padding-bottom: {padding_bottom}rem;
        }}
    </style>
    """
    hide_table_row_index = """
    <style>
        thead tr th:first-child {display:none}
        tbody th {display:none}
    </style>
    """
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)

    input_type = st.sidebar.radio(
        label='1. Choose the input type',
        on_change=clear_df,
        options=('Use one of the example sentences', 'Use your own initial sentence')
    )

    # Title
    st.header("Demo: Probing BERT's priors with serial reproduction chains")

    # Load BERT
    tokenizer, model = load_model('bert-base-uncased')
    mask_id = tokenizer.mask_token_id

    # First step: load the dataframe containing sentences
    if input_type == 'Use one of the example sentences':
        placeholder_option = '--- Please select one from below ---'
        # Example sentence -> its `sentence_num` in tsne_out.csv.
        example_sentence_nums = {
            'About 170 campers attend the camps each week.': 6,
            "Ali marpet's mother is joy rose.": 2,
            'She grew up with three brothers and ten sisters.': 8,
        }
        sentence = st.sidebar.selectbox("Select the inital sentence",
                                        (placeholder_option, *example_sentence_nums))
        if sentence != placeholder_option:
            st.session_state.df = load_data(example_sentence_nums[sentence])
            st.session_state.finished_sampling = True
    else:
        sentence = st.sidebar.text_input('Type your own sentence here.', on_change=clear_df)
        # At least two steps are needed for the t-SNE projection.
        num_steps = st.sidebar.number_input(label='How many steps do you want to run?',
                                            min_value=2, value=500)
        if st.sidebar.button('Run chains'):
            if not sentence.strip():
                st.sidebar.warning('Please type a sentence before running the chain.')
            else:
                chain = run_chains(tokenizer, model, mask_id, sentence, num_steps=num_steps)
                st.session_state.df = run_tsne(chain)
                st.session_state.finished_sampling = True

    if 'df' not in st.session_state:
        intro_text = ("Let's explore sentences from BERT's prior! "
                      "Use the menu to the left to select a pre-generated chain, "
                      "or start a new chain using your own initial sentence.")
    else:
        intro_text = ("Use the slider to select a step, or watch the autoplay. "
                      "Click 'Show candidates' to see the top proposals when each word is masked out.")
    st.empty().markdown(intro_text)

    if 'df' in st.session_state:
        df = st.session_state.df
        if 'step_id' not in st.session_state:
            st.session_state.prev_step_id = 0
            st.session_state.step_id = 0
        # The chain may have been replaced by a shorter one since the last
        # run; keep both counters inside it.
        last_step = len(df) - 1
        st.session_state.step_id = max(0, min(int(st.session_state.step_id), last_step))
        st.session_state.prev_step_id = max(0, min(int(st.session_state.prev_step_id), last_step))

        explore_type = st.sidebar.radio(
            '2. Choose how to explore the chain',
            options=['Click through steps', 'Autoplay']
        )
        if explore_type == 'Autoplay':
            st.empty()
            st.sidebar.empty()
            autoplay()
        elif explore_type == 'Click through steps':
            initialize_buttons()

        with st.container():
            st.markdown(show_changed_site(), unsafe_allow_html=True)
            fig = show_tsne_panel(df, st.session_state.step_id)
            cols = st.columns([1, 2, 1])
            with cols[1]:
                st.pyplot(fig)
            plt.close(fig)