nazneen committed on
Commit
945ee1d
·
1 Parent(s): 050aca6

multi error clusters

Browse files
Files changed (1) hide show
  1. app.py +29 -17
app.py CHANGED
@@ -3,7 +3,7 @@
3
  import numpy as np
4
  import pandas as pd
5
  import torch
6
- import json
7
  from tqdm import tqdm
8
  from math import floor
9
  from datasets import load_dataset
@@ -104,7 +104,8 @@ def quant_panel(embedding_df):
104
  st.markdown("* Each **point** is an input example.")
105
  st.markdown("* Gray points have low-loss and the colored have high-loss. High-loss instances are clustered using **kmeans** and each color represents a cluster.")
106
  st.markdown("* The **shape** of each point reflects the label category -- positive (diamond) or negative sentiment (circle).")
107
- st.altair_chart(data_comparison(down_samp(embedding_df)), use_container_width=True)
 
108
 
109
 
110
  def frequent_tokens(data, tokenizer, loss_quantile=0.95, top_k=200, smoothing=0.005):
@@ -136,7 +137,7 @@ def frequent_tokens(data, tokenizer, loss_quantile=0.95, top_k=200, smoothing=0.
136
  for i, (token) in enumerate(tokens_sorted[:top_k]):
137
  top_tokens.append(['%10s' % (tokenizer.decode(token)), '%.4f' % (token_frequencies[token]), '%.4f' % (
138
  token_frequencies_error[token]), '%4.2f' % (token_lrs[token])])
139
- return pd.DataFrame(top_tokens, columns=['Token', 'Freq', 'Freq error slice', 'lrs'])
140
 
141
 
142
  @st.cache(ttl=600)
@@ -160,12 +161,12 @@ def clustering(data,num_clusters):
160
  return data, assigned_clusters
161
 
162
  def kmeans(df, num_clusters=3):
163
- data_hl = df.loc[df['slice'] == 'high-loss']
164
- data_kmeans,clusters = clustering(data_hl,num_clusters)
165
- merged = pd.merge(df, data_kmeans, left_index=True, right_index=True, how='outer', suffixes=('', '_y'))
166
- merged.drop(merged.filter(regex='_y$').columns.tolist(),axis=1,inplace=True)
167
- merged['cluster'] = merged['cluster'].fillna(num_clusters).astype('int')
168
- return merged
169
 
170
  def distance_from_centroid(row):
171
  return sdist.norm(row['embedding'] - row['centroid'].tolist())
@@ -173,16 +174,16 @@ def distance_from_centroid(row):
173
  @st.cache(ttl=600)
174
  def topic_distribution(weights, smoothing=0.01):
175
  topic_frequencies = defaultdict(float)
176
- topic_frequencies_spotlight = defaultdict(float)
177
  weights_uniform = np.full_like(weights, 1 / len(weights))
178
  num_examples = len(weights)
179
  for i in range(num_examples):
180
  example = dataset[i]
181
  category = example['title']
182
  topic_frequencies[category] += weights_uniform[i]
183
- topic_frequencies_spotlight[category] += weights[i]
184
 
185
- topic_ratios = {c: (smoothing + topic_frequencies_spotlight[c]) / (
186
  smoothing + topic_frequencies[c]) for c in topic_frequencies}
187
 
188
  categories_sorted = map(lambda x: x[0], sorted(
@@ -191,11 +192,9 @@ def topic_distribution(weights, smoothing=0.01):
191
  topic_distr = []
192
  for category in categories_sorted:
193
  topic_distr.append(['%.3f' % topic_frequencies[category], '%.3f' %
194
- topic_frequencies_spotlight[category], '%.2f' % topic_ratios[category], '%s' % category])
195
 
196
  return pd.DataFrame(topic_distr, columns=['Overall frequency', 'Error frequency', 'Ratio', 'Category'])
197
- # for category in categories_sorted:
198
- # return(topic_frequencies[category], topic_frequencies_spotlight[category], topic_ratios[category], category)
199
 
200
  def populate_session(dataset,model):
201
  data_df = read_file_to_df('./assets/data/'+dataset+ '_'+ model+'.parquet')
@@ -239,13 +238,17 @@ if __name__ == "__main__":
239
  #populate_session(dataset, model)
240
  data_df = read_file_to_df('./assets/data/'+dataset+ '_'+ model+'.parquet')
241
  loss_quantile = st.sidebar.slider(
242
- "Loss Quantile", min_value=0.5, max_value=1.0,step=0.01,value=0.95
243
  )
 
244
  data_df['loss'] = data_df['loss'].astype(float)
245
  losses = data_df['loss']
246
  high_loss = losses.quantile(loss_quantile)
247
  data_df['slice'] = 'high-loss'
248
  data_df['slice'] = data_df['slice'].where(data_df['loss'] > high_loss, 'low-loss')
 
 
 
249
 
250
  with lcol:
251
  st.markdown('<h3>Error Slices</h3>',unsafe_allow_html=True)
@@ -279,7 +282,16 @@ if __name__ == "__main__":
279
 
280
  if run_kmeans == 'True':
281
  with st.spinner(text='running kmeans...'):
282
- merged = kmeans(data_df,num_clusters=num_clusters)
 
 
 
 
 
 
 
 
 
283
 
284
  with st.spinner(text='loading visualization...'):
285
  quant_panel(merged)
 
3
  import numpy as np
4
  import pandas as pd
5
  import torch
6
+ import math
7
  from tqdm import tqdm
8
  from math import floor
9
  from datasets import load_dataset
 
104
  st.markdown("* Each **point** is an input example.")
105
  st.markdown("* Gray points have low-loss and the colored have high-loss. High-loss instances are clustered using **kmeans** and each color represents a cluster.")
106
  st.markdown("* The **shape** of each point reflects the label category -- positive (diamond) or negative sentiment (circle).")
107
+ #st.altair_chart(data_comparison(down_samp(embedding_df)), use_container_width=True)
108
+ st.altair_chart(data_comparison(embedding_df), use_container_width=True)
109
 
110
 
111
  def frequent_tokens(data, tokenizer, loss_quantile=0.95, top_k=200, smoothing=0.005):
 
137
  for i, (token) in enumerate(tokens_sorted[:top_k]):
138
  top_tokens.append(['%10s' % (tokenizer.decode(token)), '%.4f' % (token_frequencies[token]), '%.4f' % (
139
  token_frequencies_error[token]), '%4.2f' % (token_lrs[token])])
140
+ return pd.DataFrame(top_tokens, columns=['Token', 'Freq', 'Freq error slice', 'Ratio w/ smoothing'])
141
 
142
 
143
  @st.cache(ttl=600)
 
161
  return data, assigned_clusters
162
 
163
  def kmeans(df, num_clusters=3):
164
+ #data_hl = df.loc[df['slice'] == 'high-loss']
165
+ data_kmeans,clusters = clustering(df,num_clusters)
166
+ #merged = pd.merge(df, data_kmeans, left_index=True, right_index=True, how='outer', suffixes=('', '_y'))
167
+ #merged.drop(merged.filter(regex='_y$').columns.tolist(),axis=1,inplace=True)
168
+ #merged['cluster'] = merged['cluster'].fillna(num_clusters).astype('int')
169
+ return data_kmeans
170
 
171
  def distance_from_centroid(row):
172
  return sdist.norm(row['embedding'] - row['centroid'].tolist())
 
174
  @st.cache(ttl=600)
175
  def topic_distribution(weights, smoothing=0.01):
176
  topic_frequencies = defaultdict(float)
177
+ topic_frequencies_error= defaultdict(float)
178
  weights_uniform = np.full_like(weights, 1 / len(weights))
179
  num_examples = len(weights)
180
  for i in range(num_examples):
181
  example = dataset[i]
182
  category = example['title']
183
  topic_frequencies[category] += weights_uniform[i]
184
+ topic_frequencies_error[category] += weights[i]
185
 
186
+ topic_ratios = {c: (smoothing + topic_frequencies_error[c]) / (
187
  smoothing + topic_frequencies[c]) for c in topic_frequencies}
188
 
189
  categories_sorted = map(lambda x: x[0], sorted(
 
192
  topic_distr = []
193
  for category in categories_sorted:
194
  topic_distr.append(['%.3f' % topic_frequencies[category], '%.3f' %
195
+ topic_frequencies_error[category], '%.2f' % topic_ratios[category], '%s' % category])
196
 
197
  return pd.DataFrame(topic_distr, columns=['Overall frequency', 'Error frequency', 'Ratio', 'Category'])
 
 
198
 
199
  def populate_session(dataset,model):
200
  data_df = read_file_to_df('./assets/data/'+dataset+ '_'+ model+'.parquet')
 
238
  #populate_session(dataset, model)
239
  data_df = read_file_to_df('./assets/data/'+dataset+ '_'+ model+'.parquet')
240
  loss_quantile = st.sidebar.slider(
241
+ "Loss Quantile", min_value=0.5, max_value=1.0,step=0.01,value=0.99
242
  )
243
+ data_df = data_df.drop(data_df[data_df.pred == data_df.label].index) #drop rows that are not errors
244
  data_df['loss'] = data_df['loss'].astype(float)
245
  losses = data_df['loss']
246
  high_loss = losses.quantile(loss_quantile)
247
  data_df['slice'] = 'high-loss'
248
  data_df['slice'] = data_df['slice'].where(data_df['loss'] > high_loss, 'low-loss')
249
+ data_hl = data_df.drop(data_df[data_df['slice'] == 'low-loss'].index) #drop rows that are not hl
250
+ data_ll = data_df.drop(data_df[data_df['slice'] == 'high-loss'].index)
251
+ df_list = [d for _, d in data_hl.groupby(['label'])] # this is to allow clustering over each error type. fp, fn for binary classification
252
 
253
  with lcol:
254
  st.markdown('<h3>Error Slices</h3>',unsafe_allow_html=True)
 
282
 
283
  if run_kmeans == 'True':
284
  with st.spinner(text='running kmeans...'):
285
+ merged = pd.DataFrame()
286
+ ind=0
287
+ for df in df_list:
288
+ #num_clusters= int(math.sqrt(len(df)/2))
289
+ kmeans_df = kmeans(df,num_clusters=num_clusters)
290
+ #print(kmeans_df.loc[kmeans_df['cluster'].idxmax()])
291
+ kmeans_df['cluster'] = kmeans_df['cluster'] + ind*num_clusters
292
+ ind = ind+1
293
+ merged = pd.concat([merged, kmeans_df])
294
+ merged = pd.concat([merged, data_ll])
295
 
296
  with st.spinner(text='loading visualization...'):
297
  quant_panel(merged)