Michelle Lam committed on
Commit
da6aa93
1 Parent(s): c302134

Adds comments on key utils functions; preliminary streamlining of cluster results plot

Browse files
Files changed (2) hide show
  1. audit_utils.py +25 -18
  2. server.py +8 -2
audit_utils.py CHANGED
@@ -431,6 +431,7 @@ def plot_class_cond_results(preds_df, breakdown_axis, perf_metric, other_ids, so
431
 
432
  return combined
433
 
 
434
  def show_overall_perf(variant, error_type, cur_user, threshold=TOXIC_THRESHOLD, breakdown_axis=None, topic_vis_method="median"):
435
  # Your perf (calculate using model and testset)
436
  breakdown_axis = readable_to_internal[breakdown_axis]
@@ -447,7 +448,7 @@ def show_overall_perf(variant, error_type, cur_user, threshold=TOXIC_THRESHOLD,
447
  topic_overview_plot_json = json.load(f)
448
  else:
449
  preds_df_mod = preds_df.merge(comments_grouped_full_topic_cat, on="item_id", how="left", suffixes=('_', '_avg'))
450
- if topic_vis_method == "median":
451
  preds_df_mod_grp = preds_df_mod.groupby(["topic_", "user_id"]).median()
452
  elif topic_vis_method == "mean":
453
  preds_df_mod_grp = preds_df_mod.groupby(["topic_", "user_id"]).mean()
@@ -737,7 +738,7 @@ def train_updated_model(model_name, last_label_i, ratings, user, top_n=20, topic
737
 
738
  mae, mse, rmse, avg_diff = user_perf_metrics[model_name]
739
 
740
- cur_preds_df = get_preds_df(cur_model, ["A"], sys_eval_df=ratings_df_full, topic=topic, model_name=model_name) # Just get results for user
741
 
742
  # Save this batch of labels
743
  with open(os.path.join(module_dir, label_dir, f"{last_label_i + 1}.pkl"), "wb") as f:
@@ -827,7 +828,12 @@ def get_predictions_by_user_and_item(predictions):
827
  user_item_preds[(uid, iid)] = est
828
  return user_item_preds
829
 
830
- def get_preds_df(model, user_ids, orig_df=ratings_df_full, avg_ratings_df=comments_grouped_full_topic_cat, sys_eval_df=sys_eval_df, bins=BINS, topic=None, model_name=None):
 
 
 
 
 
831
  # Prep dataframe for all predictions we'd like to request
832
  start = time.time()
833
  sys_eval_comment_ids = sys_eval_df.item_id.unique().tolist()
@@ -861,9 +867,14 @@ def get_preds_df(model, user_ids, orig_df=ratings_df_full, avg_ratings_df=commen
861
 
862
  return df
863
 
 
 
 
 
 
864
  def train_user_model(ratings_df, train_df=train_df, model_eval_df=model_eval_df, train_frac=0.75, model_type="SVD", sim_type=None, user_based=True):
865
  # Sample from shuffled labeled dataframe and add batch to train set; specified set size to model_eval set
866
- labeled = ratings_df.sample(frac=1)
867
  batch_size = math.floor(len(labeled) * train_frac)
868
  labeled_train = labeled[:batch_size]
869
  labeled_model_eval = labeled[batch_size:]
@@ -876,6 +887,10 @@ def train_user_model(ratings_df, train_df=train_df, model_eval_df=model_eval_df,
876
 
877
  return model, perf, labeled_train, labeled_model_eval
878
 
 
 
 
 
879
  def train_model(train_df, model_eval_df, model_type="SVD", sim_type=None, user_based=True):
880
  # Train model
881
  reader = Reader(rating_scale=(0, 4))
@@ -1126,6 +1141,7 @@ def get_comment_url(row):
1126
  def get_topic_url(row):
1127
  return f"#{row['topic_']}/#topic"
1128
 
 
1129
  def plot_overall_vis(preds_df, error_type, cur_user, cur_model, n_topics=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, bin_step=0.05):
1130
  df = preds_df.copy().reset_index()
1131
 
@@ -1242,22 +1258,15 @@ def plot_overall_vis(preds_df, error_type, cur_user, cur_model, n_topics=None, b
1242
 
1243
  return plot
1244
 
1245
- def get_cluster_overview_plot(preds_df, error_type, threshold=TOXIC_THRESHOLD, use_model=True):
1246
- preds_df_mod = preds_df.merge(comments_grouped_full_topic_cat, on="item_id", how="left", suffixes=('_', '_avg'))
1247
-
1248
- if use_model:
1249
- return plot_overall_vis_cluster(preds_df_mod, error_type=error_type, n_comments=500, threshold=threshold)
1250
- else:
1251
- return plot_overall_vis_cluster2(preds_df_mod, error_type=error_type, n_comments=500, threshold=threshold)
1252
-
1253
- def plot_overall_vis_cluster2(preds_df, error_type, n_comments=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, bin_step=0.05):
1254
  df = preds_df.copy().reset_index()
1255
 
1256
  df["vis_pred_bin"], out_bins = pd.cut(df["rating"], bins, labels=VIS_BINS_LABELS, retbins=True)
1257
  df = df[df["user_id"] == "A"].sort_values(by=["rating"]).reset_index()
1258
  df["system_label"] = [("toxic" if r > threshold else "non-toxic") for r in df["rating"].tolist()]
1259
  df["key"] = [get_key_no_model(sys, threshold) for sys in df["rating"].tolist()]
1260
- print("len(df)", len(df)) # always 0 for some reason (from keyword search)
1261
  df["category"] = df.apply(lambda row: get_category(row), axis=1)
1262
  df["url"] = df.apply(lambda row: get_comment_url(row), axis=1)
1263
 
@@ -1345,17 +1354,15 @@ def plot_overall_vis_cluster2(preds_df, error_type, n_comments=None, bins=VIS_BI
1345
  final_plot = (bkgd + annotation + chart + rule).properties(height=(plot_dim_height), width=plot_dim_width).resolve_scale(color='independent').to_json()
1346
 
1347
  return final_plot, df
1348
-
 
1349
  def plot_overall_vis_cluster(preds_df, error_type, n_comments=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, bin_step=0.05):
1350
  df = preds_df.copy().reset_index(drop=True)
1351
- # df = df[df["topic_"] == topic]
1352
 
1353
  df["vis_pred_bin"], out_bins = pd.cut(df["pred"], bins, labels=VIS_BINS_LABELS, retbins=True)
1354
  df = df[df["user_id"] == "A"].sort_values(by=["rating"]).reset_index(drop=True)
1355
  df["system_label"] = [("toxic" if r > threshold else "non-toxic") for r in df["rating"].tolist()]
1356
  df["key"] = [get_key(sys, user, threshold) for sys, user in zip(df["rating"].tolist(), df["pred"].tolist())]
1357
- print("len(df)", len(df)) # always 0 for some reason (from keyword search)
1358
- # print("columns", df.columns)
1359
  df["category"] = df.apply(lambda row: get_category(row), axis=1)
1360
  df["url"] = df.apply(lambda row: get_comment_url(row), axis=1)
1361
 
 
431
 
432
  return combined
433
 
434
+ # Generates the summary plot across all topics for the user
435
  def show_overall_perf(variant, error_type, cur_user, threshold=TOXIC_THRESHOLD, breakdown_axis=None, topic_vis_method="median"):
436
  # Your perf (calculate using model and testset)
437
  breakdown_axis = readable_to_internal[breakdown_axis]
 
448
  topic_overview_plot_json = json.load(f)
449
  else:
450
  preds_df_mod = preds_df.merge(comments_grouped_full_topic_cat, on="item_id", how="left", suffixes=('_', '_avg'))
451
+ if topic_vis_method == "median": # Default
452
  preds_df_mod_grp = preds_df_mod.groupby(["topic_", "user_id"]).median()
453
  elif topic_vis_method == "mean":
454
  preds_df_mod_grp = preds_df_mod.groupby(["topic_", "user_id"]).mean()
 
738
 
739
  mae, mse, rmse, avg_diff = user_perf_metrics[model_name]
740
 
741
+ cur_preds_df = get_preds_df(cur_model, ["A"], sys_eval_df=ratings_df_full) # Just get results for user
742
 
743
  # Save this batch of labels
744
  with open(os.path.join(module_dir, label_dir, f"{last_label_i + 1}.pkl"), "wb") as f:
 
828
  user_item_preds[(uid, iid)] = est
829
  return user_item_preds
830
 
831
+ # Pre-computes predictions for the provided model and specified users on the system-eval dataset
832
+ # - model: trained model
833
+ # - user_ids: list of user IDs to compute predictions for
834
+ # - avg_ratings_df: dataframe of average ratings for each comment (pre-computed)
835
+ # - sys_eval_df: dataframe of system eval labels (pre-computed)
836
+ def get_preds_df(model, user_ids, avg_ratings_df=comments_grouped_full_topic_cat, sys_eval_df=sys_eval_df, bins=BINS):
837
  # Prep dataframe for all predictions we'd like to request
838
  start = time.time()
839
  sys_eval_comment_ids = sys_eval_df.item_id.unique().tolist()
 
867
 
868
  return df
869
 
870
+ # Given the full set of ratings, trains the specified model type and evaluates on the model eval set
871
+ # - ratings_df: dataframe of all ratings
872
+ # - train_df: dataframe of training labels
873
+ # - model_eval_df: dataframe of model eval labels (validation set)
874
+ # - train_frac: fraction of ratings to use for training
875
  def train_user_model(ratings_df, train_df=train_df, model_eval_df=model_eval_df, train_frac=0.75, model_type="SVD", sim_type=None, user_based=True):
876
  # Sample from shuffled labeled dataframe and add batch to train set; specified set size to model_eval set
877
+ labeled = ratings_df.sample(frac=1) # Shuffle the data
878
  batch_size = math.floor(len(labeled) * train_frac)
879
  labeled_train = labeled[:batch_size]
880
  labeled_model_eval = labeled[batch_size:]
 
887
 
888
  return model, perf, labeled_train, labeled_model_eval
889
 
890
+ # Given a set of labels split into training and validation (model_eval), trains the specified model type on the training labels and evaluates on the model_eval labels
891
+ # - train_df: dataframe of training labels
892
+ # - model_eval_df: dataframe of model eval labels (validation set)
893
+ # - model_type: type of model to train
894
  def train_model(train_df, model_eval_df, model_type="SVD", sim_type=None, user_based=True):
895
  # Train model
896
  reader = Reader(rating_scale=(0, 4))
 
1141
  def get_topic_url(row):
1142
  return f"#{row['topic_']}/#topic"
1143
 
1144
+ # Plots overall results histogram (each block is a topic)
1145
  def plot_overall_vis(preds_df, error_type, cur_user, cur_model, n_topics=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, bin_step=0.05):
1146
  df = preds_df.copy().reset_index()
1147
 
 
1258
 
1259
  return plot
1260
 
1261
+ # Plots cluster results histogram (each block is a comment), but *without* a model
1262
+ # as a point of reference (in contrast to plot_overall_vis_cluster)
1263
+ def plot_overall_vis_cluster_no_model(preds_df, n_comments=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, bin_step=0.05):
 
 
 
 
 
 
1264
  df = preds_df.copy().reset_index()
1265
 
1266
  df["vis_pred_bin"], out_bins = pd.cut(df["rating"], bins, labels=VIS_BINS_LABELS, retbins=True)
1267
  df = df[df["user_id"] == "A"].sort_values(by=["rating"]).reset_index()
1268
  df["system_label"] = [("toxic" if r > threshold else "non-toxic") for r in df["rating"].tolist()]
1269
  df["key"] = [get_key_no_model(sys, threshold) for sys in df["rating"].tolist()]
 
1270
  df["category"] = df.apply(lambda row: get_category(row), axis=1)
1271
  df["url"] = df.apply(lambda row: get_comment_url(row), axis=1)
1272
 
 
1354
  final_plot = (bkgd + annotation + chart + rule).properties(height=(plot_dim_height), width=plot_dim_width).resolve_scale(color='independent').to_json()
1355
 
1356
  return final_plot, df
1357
+
1358
+ # Plots cluster results histogram (each block is a comment) *with* a model as a point of reference
1359
  def plot_overall_vis_cluster(preds_df, error_type, n_comments=None, bins=VIS_BINS, threshold=TOXIC_THRESHOLD, bin_step=0.05):
1360
  df = preds_df.copy().reset_index(drop=True)
 
1361
 
1362
  df["vis_pred_bin"], out_bins = pd.cut(df["pred"], bins, labels=VIS_BINS_LABELS, retbins=True)
1363
  df = df[df["user_id"] == "A"].sort_values(by=["rating"]).reset_index(drop=True)
1364
  df["system_label"] = [("toxic" if r > threshold else "non-toxic") for r in df["rating"].tolist()]
1365
  df["key"] = [get_key(sys, user, threshold) for sys, user in zip(df["rating"].tolist(), df["pred"].tolist())]
 
 
1366
  df["category"] = df.apply(lambda row: get_category(row), axis=1)
1367
  df["url"] = df.apply(lambda row: get_comment_url(row), axis=1)
1368
 
server.py CHANGED
@@ -220,8 +220,14 @@ def get_cluster_results():
220
  if (scaffold_method == "personal_cluster") and (os.path.isfile(personal_cluster_file)):
221
  cluster_overview_plot_json, sampled_df = utils.plot_overall_vis_cluster(topic_df, error_type=error_type, n_comments=500)
222
  else:
223
- # Regular
224
- cluster_overview_plot_json, sampled_df = utils.get_cluster_overview_plot(topic_df, error_type=error_type, use_model=use_model)
 
 
 
 
 
 
225
 
226
  cluster_comments = utils.get_cluster_comments(sampled_df,error_type=error_type, num_examples=n_examples, use_model=use_model) # New version of cluster comment table
227
 
 
220
  if (scaffold_method == "personal_cluster") and (os.path.isfile(personal_cluster_file)):
221
  cluster_overview_plot_json, sampled_df = utils.plot_overall_vis_cluster(topic_df, error_type=error_type, n_comments=500)
222
  else:
223
+ # Default case
224
+ topic_df_mod = topic_df.merge(comments_grouped_full_topic_cat, on="item_id", how="left", suffixes=('_', '_avg'))
225
+ if use_model:
226
+ # Display results with the model as a reference point
227
+ cluster_overview_plot_json, sampled_df = utils.plot_overall_vis_cluster(topic_df_mod, error_type=error_type, n_comments=500)
228
+ else:
229
+ # Display results without a model
230
+ cluster_overview_plot_json, sampled_df = utils.plot_overall_vis_cluster_no_model(topic_df_mod, n_comments=500)
231
 
232
  cluster_comments = utils.get_cluster_comments(sampled_df,error_type=error_type, num_examples=n_examples, use_model=use_model) # New version of cluster comment table
233