polinaeterna HF staff commited on
Commit
3bb5a93
β€’
1 Parent(s): 3dcef48

ad toxicity check

Browse files
Files changed (1) hide show
  1. app.py +87 -8
app.py CHANGED
@@ -6,7 +6,6 @@ import multiprocessing
6
  import gradio as gr
7
  import pandas as pd
8
  import polars as pl
9
- import numpy as np
10
  import matplotlib.pyplot as plt
11
  import spaces
12
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
@@ -90,12 +89,83 @@ def plot_and_df(texts, preds):
90
  )
91
 
92
 
93
- @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  def run_quality_check(dataset, column, batch_size, num_examples):
95
- # config = "default"
96
  info_resp = session.get(f"https://datasets-server.huggingface.co/info?dataset={dataset}", timeout=3).json()
97
  if "error" in info_resp:
98
- yield "❌ " + info_resp["error"], gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), plt.Figure()
99
  return
100
  config = "default" if "default" in info_resp["dataset_info"] else next(iter(info_resp["dataset_info"]))
101
  split = "train" if "train" in info_resp["dataset_info"][config]["splits"] else next(
@@ -106,9 +176,10 @@ def run_quality_check(dataset, column, batch_size, num_examples):
106
  try:
107
  data = pl.read_parquet(f"hf://datasets/{dataset}@~parquet/{config}/partial-{split}/0000.parquet", columns=[column])
108
  except Exception as error:
109
- yield f"❌ {error}", gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), plt.Figure()
110
  return
111
  texts = data[column].to_list()
 
112
  # batch_size = 100
113
  predictions, texts_processed = [], []
114
  num_examples = min(len(texts), num_examples)
@@ -117,7 +188,7 @@ def run_quality_check(dataset, column, batch_size, num_examples):
117
  batch_predictions = predict(batch_texts)
118
  predictions.extend(batch_predictions)
119
  texts_processed.extend(batch_texts)
120
- yield {"check in progress...": (i+batch_size) / num_examples}, *plot_and_df(texts_processed, predictions), plt.Figure()
121
 
122
  with multiprocessing.Pool(processes=8) as pool:
123
  props = pool.map(proportion_non_ascii, texts)
@@ -128,7 +199,8 @@ def run_quality_check(dataset, column, batch_size, num_examples):
128
  plt.xlabel('Proportion of non-ASCII characters')
129
  plt.ylabel('Number of texts')
130
 
131
- yield {"finished": 1.}, *plot_and_df(texts_processed, predictions), plt.gcf()
 
132
 
133
  with gr.Blocks() as demo:
134
  gr.Markdown(
@@ -175,6 +247,13 @@ with gr.Blocks() as demo:
175
 
176
  # non_ascii_hist = gr.DataFrame(visible=False)
177
  non_ascii_hist = gr.Plot()
178
- gr_check_btn.click(run_quality_check, inputs=[dataset_name, text_column, batch_size, num_examples], outputs=[progress_bar, plot, df_low, df_medium, df_high, non_ascii_hist])
 
 
 
 
 
 
 
179
 
180
  demo.launch()
 
6
  import gradio as gr
7
  import pandas as pd
8
  import polars as pl
 
9
  import matplotlib.pyplot as plt
10
  import spaces
11
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
 
89
  )
90
 
91
 
92
+ PERSPECTIVE_API_KEY = os.environ.get("PERSPECTIVE_API_KEY")
93
+ PERSPECTIVE_URL = f"https://commentanalyzer.googleapis.com/v1alpha1/comments:analyze?key={PERSPECTIVE_API_KEY}"
94
+ REQUESTED_ATTRIBUTES = {"TOXICITY": {}, "SEVERE_TOXICITY": {},
95
+ "IDENTITY_ATTACK": {}, "INSULT": {}, "PROFANITY": {},
96
+ "THREAT": {}}
97
+ ATT_SCORE = "attributeScores"
98
+ SUM_SCORE = "summaryScore"
99
+
100
+
101
+ def plot_toxicity(scores):
102
+ fig, axs = plt.subplots(2, 3)#, figsize=(10, 6))
103
+ for x, y, score_name in zip([0,0,0,1,1,1], [0,1,2,0,1,2], scores):
104
+ axs[x,y].hist(scores[score_name], bins=20, range=(0., 1.))
105
+ # axs[x,y].set_title(f'Histogram of {score_name}')
106
+ axs[x,y].set_xlabel(f'{score_name}')
107
+ # axs[x,y].set_ylabel('Number of texts')
108
+ fig.supylabel("Number of texts")
109
+ fig.suptitle("Histogram of toxicity scores")
110
+ fig.tight_layout()
111
+
112
+ return fig
113
+
114
+ def call_perspective_api(texts_df, column_name):#, s):
115
+ headers = {
116
+ "content-type": "application/json",
117
+ }
118
+ req_att_scores = {attr: [] for attr in REQUESTED_ATTRIBUTES}
119
+
120
+ texts = texts_df[column_name].values
121
+ for i, text in tqdm(enumerate(texts), desc="scanning with perspective"):
122
+ data = {
123
+ "comment": {"text": text},
124
+ "languages": ["en"],
125
+ "requestedAttributes": REQUESTED_ATTRIBUTES
126
+ }
127
+ time.sleep(1)
128
+ try:
129
+ req_response = requests.post(PERSPECTIVE_URL, json=data, headers=headers)
130
+ except Exception as e:
131
+ print(e)
132
+ return req_att_scores
133
+
134
+ if req_response.ok:
135
+ response = req_response.json()
136
+ # logger.info("Perspective API response is:")
137
+ # logger.info(response)
138
+ if ATT_SCORE in response:
139
+ for req_att in REQUESTED_ATTRIBUTES:
140
+ if req_att in response[ATT_SCORE]:
141
+ att_score = response[ATT_SCORE][req_att][SUM_SCORE]["value"]
142
+ req_att_scores[req_att].append(att_score)
143
+ else:
144
+ req_att_scores[req_att].append(0)
145
+ else:
146
+ # logger.error(
147
+ # "Unexpected response format from Perspective API."
148
+ # )
149
+ raise ValueError(req_response)
150
+ else:
151
+ try:
152
+ req_response.raise_for_status()
153
+ except Exception as e:
154
+ print(e)
155
+ return req_att_scores
156
+ if i % 10 == 0:
157
+ plot_toxicity(req_att_scores)
158
+ yield plt.gcf(), pd.DataFrame()
159
+
160
+ plot_toxicity(req_att_scores)
161
+ yield plt.gcf(), pd.DataFrame.from_dict({column_name: texts, **req_att_scores})
162
+
163
+
164
+ # @spaces.GPU
165
  def run_quality_check(dataset, column, batch_size, num_examples):
 
166
  info_resp = session.get(f"https://datasets-server.huggingface.co/info?dataset={dataset}", timeout=3).json()
167
  if "error" in info_resp:
168
+ yield "❌ " + info_resp["error"], gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), plt.Figure(), pd.DataFrame(),
169
  return
170
  config = "default" if "default" in info_resp["dataset_info"] else next(iter(info_resp["dataset_info"]))
171
  split = "train" if "train" in info_resp["dataset_info"][config]["splits"] else next(
 
176
  try:
177
  data = pl.read_parquet(f"hf://datasets/{dataset}@~parquet/{config}/partial-{split}/0000.parquet", columns=[column])
178
  except Exception as error:
179
+ yield f"❌ {error}", gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), plt.Figure(), pd.DataFrame(),
180
  return
181
  texts = data[column].to_list()
182
+ texts_sample = data.sample(20, shuffle=True, seed=16).to_pandas()
183
  # batch_size = 100
184
  predictions, texts_processed = [], []
185
  num_examples = min(len(texts), num_examples)
 
188
  batch_predictions = predict(batch_texts)
189
  predictions.extend(batch_predictions)
190
  texts_processed.extend(batch_texts)
191
+ yield {"check in progress...": (i+batch_size) / num_examples}, *plot_and_df(texts_processed, predictions), plt.Figure(), pd.DataFrame()
192
 
193
  with multiprocessing.Pool(processes=8) as pool:
194
  props = pool.map(proportion_non_ascii, texts)
 
199
  plt.xlabel('Proportion of non-ASCII characters')
200
  plt.ylabel('Number of texts')
201
 
202
+ yield {"finished": 1.}, *plot_and_df(texts_processed, predictions), plt.gcf(), texts_sample
203
+
204
 
205
  with gr.Blocks() as demo:
206
  gr.Markdown(
 
247
 
248
  # non_ascii_hist = gr.DataFrame(visible=False)
249
  non_ascii_hist = gr.Plot()
250
+ texts_sample_df = gr.DataFrame(visible=False)
251
+ gr_check_btn.click(run_quality_check, inputs=[dataset_name, text_column, batch_size, num_examples], outputs=[progress_bar, plot, df_low, df_medium, df_high, non_ascii_hist, texts_sample_df])
252
+
253
+ gr_toxicity_btn = gr.Button("Run perpspective API to check toxicity of random samples.")
254
+ toxicity_hist = gr.Plot()
255
+ with gr.Accordion("Explore examples with toxicity scores:", open=False):
256
+ toxicity_df = gr.DataFrame()
257
+ gr_toxicity_btn.click(call_perspective_api, inputs=[texts_sample_df, text_column], outputs=[toxicity_hist, toxicity_df])
258
 
259
  demo.launch()