# Hugging Face Space — status: "Running on Zero" (ZeroGPU hardware)
import requests | |
from collections import Counter | |
from requests.adapters import HTTPAdapter, Retry | |
import multiprocessing | |
import os | |
import time | |
import logging | |
import gradio as gr | |
import pandas as pd | |
import polars as pl | |
import matplotlib.pyplot as plt | |
import spaces | |
from gradio_huggingfacehub_search import HuggingfaceHubSearch | |
from huggingface_hub import PyTorchModelHubMixin | |
import torch | |
from torch import nn | |
from transformers import AutoModel, AutoTokenizer, AutoConfig | |
from tqdm import tqdm | |
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")

# Shared HTTP session that retries transient gateway errors (502/503/504)
# up to 5 times with exponential backoff.
session = requests.Session()
retries = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])
# Adapters are selected by URL prefix. Every endpoint this app calls is
# https://, so the retrying adapter must be mounted for that scheme as well;
# mounting it only for "http://" left all real traffic without retries.
session.mount("http://", HTTPAdapter(max_retries=retries))
session.mount("https://", HTTPAdapter(max_retries=retries))
class QualityModel(nn.Module, PyTorchModelHubMixin):
    """Transformer encoder plus a linear head that turns text into class probabilities.

    ``config`` is the dict passed in by ``PyTorchModelHubMixin.from_pretrained``;
    the keys read here are ``base_model`` (backbone id for
    ``AutoModel.from_pretrained``), ``fc_dropout`` (dropout probability) and
    ``id2label`` (its length sets the number of output classes).
    """

    def __init__(self, config):
        super().__init__()
        self.model = AutoModel.from_pretrained(config["base_model"])
        self.dropout = nn.Dropout(config["fc_dropout"])
        hidden_size = self.model.config.hidden_size
        num_classes = len(config["id2label"])
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return softmax probabilities computed from the first token of each sequence.

        The output has shape (batch, num_classes); the first-token position is
        presumably the [CLS] token of the backbone — confirm for other base models.
        """
        encoded = self.model(input_ids=input_ids, attention_mask=attention_mask)
        token_logits = self.fc(self.dropout(encoded.last_hidden_state))
        first_token_logits = token_logits[:, 0, :]
        return first_token_logits.softmax(dim=1)
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# One hub repository provides the label mapping, the tokenizer and the weights.
_CHECKPOINT = "nvidia/quality-classifier-deberta"
config = AutoConfig.from_pretrained(_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = QualityModel.from_pretrained(_CHECKPOINT).to(device)
model.eval()  # inference only: puts dropout into evaluation (identity) mode
# On a ZeroGPU Space, CUDA work must run inside a function decorated with
# `spaces.GPU` (the otherwise-unused `import spaces` exists for this); the
# decorator is documented to have no effect outside ZeroGPU.
@spaces.GPU
def predict(texts: list[str]) -> list[str]:
    """Classify each text with the quality model.

    Args:
        texts: raw strings. The batch is padded to its longest member and
            sequences longer than the tokenizer's maximum are truncated.

    Returns:
        One label per input, in input order, looked up in ``config.id2label``.
        ``plot_and_df`` counts the labels "Low", "Medium" and "High".
    """
    if not texts:
        # Nothing to classify; avoid tokenizing / running the model on an empty batch.
        return []
    inputs = tokenizer(
        texts, return_tensors="pt", padding="longest", truncation=True
    ).to(device)
    # Pure inference: without no_grad every batch would build an autograd graph
    # and keep its activations alive until the outputs are released.
    with torch.no_grad():
        probabilities = model(inputs["input_ids"], inputs["attention_mask"])
    class_indices = probabilities.argmax(dim=1).tolist()
    return [config.id2label[class_idx] for class_idx in class_indices]
def plot_and_df(texts, preds):
    """Summarise predictions as a bar plot plus up to 20 example texts per class.

    Args:
        texts: classified texts, aligned with ``preds``.
        preds: predicted label for each text.

    Returns:
        ``(bar_plot, low_df, medium_df, high_df)``. The bar plot counts the
        labels "Low", "Medium" and "High" (0 when absent). Each DataFrame
        holds a single ``text`` column with the first 20 texts of that class.
    """
    classes = ("Low", "Medium", "High")
    tally = Counter(preds)
    counts_df = pd.DataFrame(
        {
            "quality": list(classes),
            "count": [tally.get(label, 0) for label in classes],
        }
    )
    texts_df = pd.DataFrame({"quality": preds, "text": texts})
    examples = [
        texts_df.loc[texts_df["quality"] == label, ["text"]].head(20)
        for label in classes
    ]
    return (gr.BarPlot(counts_df, x="quality", y="count"), *examples)
def _quality_error_outputs(message):
    """Build the six click outputs for a failed check: an error label plus empty components."""
    return (
        f"❌ {message}",
        gr.BarPlot(),
        pd.DataFrame(),
        pd.DataFrame(),
        pd.DataFrame(),
        pd.DataFrame(),
    )


def _read_first_parquet_shard(dataset, config_name, split, column):
    """Read ``column`` from shard 0000 of the dataset's auto-converted parquet export.

    The split directory may be named ``{split}``, ``partial-{split}`` or
    ``{split}-part0``. They are tried in that order, moving to the next one only
    when polars raises ``ComputeError``. The final attempt's error, and any
    other exception, propagates to the caller.
    """
    split_dirs = (split, f"partial-{split}", f"{split}-part0")
    for attempt, split_dir in enumerate(split_dirs, start=1):
        path = f"hf://datasets/{dataset}@~parquet/{config_name}/{split_dir}/0000.parquet"
        try:
            return pl.read_parquet(path, columns=[column])
        except pl.exceptions.ComputeError:
            if attempt == len(split_dirs):
                raise


def run_quality_check(dataset, column, batch_size, num_examples):
    """Classify the first ``num_examples`` texts of a Hub dataset, streaming results.

    The ``default`` config (else the first listed) and the ``train`` split
    (else the first listed) are taken from the datasets-server ``/info``
    response. ``column`` is read from the first shard of the parquet export and
    classified in batches of ``batch_size``.

    Yields:
        6-tuples matching the click outputs: progress label, bar plot,
        Low/Medium/High example tables, and the loaded data (an empty
        DataFrame until the final yield). On failure, a single tuple with an
        error message and empty components is yielded instead.
    """
    try:
        # params= URL-encodes the dataset id instead of pasting it into the URL.
        info_resp = session.get(
            "https://datasets-server.huggingface.co/info",
            params={"dataset": dataset},
            timeout=3,
        ).json()
    except (requests.RequestException, ValueError) as error:  # network failure / non-JSON body
        yield _quality_error_outputs(error)
        return
    if "error" in info_resp:
        yield _quality_error_outputs(info_resp["error"])
        return

    dataset_info = info_resp["dataset_info"]
    # Named config_name so it is not confused with the model's global `config`.
    config_name = "default" if "default" in dataset_info else next(iter(dataset_info))
    splits = dataset_info[config_name]["splits"]
    split = "train" if "train" in splits else next(iter(splits))

    try:
        data = _read_first_parquet_shard(dataset, config_name, split, column)
    except Exception as error:  # any read failure is reported in the UI
        yield _quality_error_outputs(error)
        return

    # UI components may deliver floats; range() needs ints and a step of at least 1.
    batch_size = max(1, int(batch_size))
    num_examples = max(0, min(data.height, int(num_examples)))
    # Only the rows that will be classified are materialised; each is cut to
    # its first 10,000 characters before tokenization.
    texts = [text[:10000] for text in data[column].head(num_examples).to_list()]

    predictions, texts_processed = [], []
    for start in range(0, num_examples, batch_size):
        batch_texts = texts[start:start + batch_size]
        predictions.extend(predict(batch_texts))
        texts_processed.extend(batch_texts)
        progress = len(texts_processed) / num_examples
        yield {"check in progress...": progress}, *plot_and_df(texts_processed, predictions), pd.DataFrame()
    yield {"finished": 1.}, *plot_and_df(texts_processed, predictions), data
# Perspective API endpoint; the key is read from the environment and sent as
# a query parameter (the URL contains "key=None" when the variable is unset).
PERSPECTIVE_API_KEY = os.environ.get("PERSPECTIVE_API_KEY")
PERSPECTIVE_URL = (
    "https://commentanalyzer.googleapis.com/v1alpha1/comments:analyze"
    f"?key={PERSPECTIVE_API_KEY}"
)
# Attributes requested for every text, each with default (empty) options.
# Their order also fixes the subplot order in plot_toxicity.
REQUESTED_ATTRIBUTES = {
    attribute: {}
    for attribute in (
        "TOXICITY",
        "SEVERE_TOXICITY",
        "IDENTITY_ATTACK",
        "INSULT",
        "PROFANITY",
        "THREAT",
    )
}
# Response keys: attributeScores -> <ATTRIBUTE> -> summaryScore -> value.
ATT_SCORE = "attributeScores"
SUM_SCORE = "summaryScore"
def plot_toxicity(scores):
    """Draw one histogram per score list on a 2x3 grid and return the figure.

    ``scores`` maps an attribute name to a list of scores. Up to six entries are
    drawn, filling the grid row by row in the mapping's iteration order, each
    with 20 bins over [0, 1].
    """
    fig, axes = plt.subplots(2, 3)  # , figsize=(10, 6))
    # axes.flat walks the grid row-major: (0,0), (0,1), (0,2), (1,0), ...
    for ax, attribute in zip(axes.flat, scores):
        ax.hist(scores[attribute], bins=20, range=(0., 1.))
        ax.set_xlabel(attribute)
    fig.supylabel("Number of texts")
    fig.suptitle("Histogram of toxicity scores")
    fig.tight_layout()
    return fig
def _perspective_scores(text):
    """Score one text with the Perspective API.

    Returns:
        ``{attribute: summary score}`` for every key of REQUESTED_ATTRIBUTES,
        using 0 for attributes missing from the response.

    Raises:
        requests.RequestException: the request failed, timed out, or returned
            an HTTP error status.
        ValueError: the response has no ``attributeScores`` field.
    """
    payload = {
        "comment": {"text": text},
        "languages": ["en"],
        "requestedAttributes": REQUESTED_ATTRIBUTES,
    }
    response = requests.post(
        PERSPECTIVE_URL,
        json=payload,
        headers={"content-type": "application/json"},
        timeout=30,  # never let a stalled request hang the UI indefinitely
    )
    response.raise_for_status()
    body = response.json()
    if ATT_SCORE not in body:
        raise ValueError(f"Unexpected response format from Perspective API: {body}")
    attribute_scores = body[ATT_SCORE]
    return {
        attr: attribute_scores[attr][SUM_SCORE]["value"] if attr in attribute_scores else 0
        for attr in REQUESTED_ATTRIBUTES
    }


def call_perspective_api(texts_df, column_name, full_check=False):
    """Score texts with the Perspective API, streaming progress to the UI.

    By default a reproducible sample of up to 100 rows (``random_state=16``)
    is scored; with ``full_check`` every row of ``texts_df[column_name]`` is.
    One request is sent per text, one second apart.

    Yields:
        ``(progress, figure, scores_df)`` tuples: after the first text and every
        tenth text after that, and once more when finished. If a request fails,
        a final tuple carries the error message together with the partial
        results gathered so far.

    Raises:
        ValueError: the API answered with an unexpected response format.
    """
    if texts_df is None or column_name not in texts_df.columns:
        yield f"❌ Column {column_name!r} not found; run the quality check first.", None, pd.DataFrame()
        return
    if not full_check:
        # DataFrame.sample raises if asked for more rows than exist.
        texts_df = texts_df.sample(min(100, len(texts_df)), random_state=16)
    texts = texts_df[column_name].values
    n_samples = len(texts)
    req_att_scores = {attr: [] for attr in REQUESTED_ATTRIBUTES}
    fig = None

    def redraw():
        # Close the previously yielded figure before drawing a new one; otherwise
        # every redraw leaves another figure open in pyplot's global registry.
        nonlocal fig
        if fig is not None:
            plt.close(fig)
        fig = plot_toxicity(req_att_scores)
        return fig

    def scores_frame(n_done):
        # The first n_done texts alongside their (equally long) score columns.
        return pd.DataFrame.from_dict({column_name: texts[:n_done], **req_att_scores})

    for i, text in tqdm(enumerate(texts), total=n_samples, desc="scanning with perspective"):
        time.sleep(1)  # throttle to one request per second (presumably the API quota)
        try:
            scores = _perspective_scores(text)
        except requests.RequestException as error:
            # A bare `return value` in a generator discards the value, so report
            # the failure and the partial results through one last yield.
            logging.error("Perspective API request failed: %s", error)
            yield f"❌ {error}", redraw(), scores_frame(i)
            return
        for attr, value in scores.items():
            req_att_scores[attr].append(value)
        if i % 10 == 0:
            yield {"toxicity check in progress...": i / n_samples}, redraw(), scores_frame(i + 1)
    yield {"toxicity check finished.": 1.}, redraw(), scores_frame(n_samples)
def proportion_non_ascii(s):
    """
    Return the share of characters in a string whose code point is above 127.
    Parameters:
    s (str): The text to inspect.
    Returns:
    float: Non-ASCII character count divided by the length of ``s``;
    0.0 for an empty string.
    """
    length = len(s)
    if length == 0:
        return 0.0
    # str.isascii() on a single character is True exactly when its code point is < 128.
    outside_ascii = sum(not ch.isascii() for ch in s)
    return outside_ascii / length
def non_ascii_check(texts_df, column_name):
    """Plot a histogram of the per-text proportion of non-ASCII characters.

    Args:
        texts_df: table holding the texts (the hidden DataFrame filled by the
            quality check).
        column_name: column of ``texts_df`` that contains the texts.

    Returns:
        A new matplotlib figure with 20 bins over [0, 1].

    Raises:
        gr.Error: ``texts_df`` is missing or has no ``column_name`` column.
    """
    if texts_df is None or column_name not in texts_df.columns:
        raise gr.Error(f"Column {column_name!r} not found; run the quality check first.")
    # Only strings can be scored; missing cells (None/NaN) would crash the character loop.
    texts = [text for text in texts_df[column_name].to_list() if isinstance(text, str)]
    with multiprocessing.Pool(processes=8) as pool:
        props = pool.map(proportion_non_ascii, texts)
    # Draw on a fresh figure. The implicit plt.hist/plt.title target pyplot's
    # *current* figure, which may be a previous histogram or the toxicity grid;
    # bars would then accumulate on, or land in, the wrong axes.
    fig, ax = plt.subplots()
    ax.hist(props, bins=20, range=(0., 1.))
    ax.set_title('Histogram of proportion of non-ASCII characters')
    ax.set_xlabel('Proportion of non-ASCII characters')
    ax.set_ylabel('Number of texts')
    return fig
with gr.Blocks() as demo:
    # gr.Markdown dedents its text, so the indentation inside these strings is harmless.
    gr.Markdown(
        """
        # π« Dataset Quality Checker π«
        Use [nvidia/quality-classifier-deberta](https://huggingface.co/nvidia/quality-classifier-deberta) on any text dataset on the Hub.
        ## Select dataset and text column
        """
    )
    dataset_name = HuggingfaceHubSearch(
        label="Hub Dataset ID",
        placeholder="Search for dataset id on Huggingface",
        search_type="dataset",
        # value="fka/awesome-chatgpt-prompts",
    )
    # config_name = "default"  # TODO: user input
    with gr.Accordion("Dataset preview", open=False):
        # NOTE(review): `embed` is never attached to an event, so this accordion
        # currently renders nothing; it also always embeds the default/train
        # viewer regardless of the dataset's actual config and split.
        def embed(name):
            html_code = f"""
            <iframe
                src="https://huggingface.co/datasets/{name}/embed/viewer/default/train"
                frameborder="0"
                width="100%"
                height="700px"
            ></iframe>
            """
            return gr.HTML(value=html_code)
    text_column = gr.Textbox(placeholder="text", label="Text column name to check (data must be non-nested, raw texts!)")

    gr.Markdown("## Run nvidia quality classifier")
    # Minimum of 8 (one step): a batch size of 0 makes the batching range() raise.
    batch_size = gr.Slider(8, 128, 32, step=8, label="Inference batch size (set this to smaller value if this space crashes.)")
    # precision=0 makes the component return an int; range() rejects floats.
    num_examples = gr.Number(500, precision=0, label="Number of first examples to check")
    gr_check_btn = gr.Button("Check Dataset")
    progress_bar = gr.Label(show_label=False)
    plot = gr.BarPlot()
    with gr.Accordion("Explore some individual examples for each class", open=False):
        gr.Markdown("### Low")
        df_low = gr.DataFrame()
        gr.Markdown("### Medium")
        df_medium = gr.DataFrame()
        gr.Markdown("### High")
        df_high = gr.DataFrame()
    # Hidden holder for the loaded data, reused by the measures and toxicity steps.
    texts_df = gr.DataFrame(visible=False)
    gr_check_btn.click(
        run_quality_check,
        inputs=[dataset_name, text_column, batch_size, num_examples],
        outputs=[progress_bar, plot, df_low, df_medium, df_high, texts_df]
    )

    gr.Markdown("""## Compute text quality measures
    * proportion of non-ascii characters
    * #TODO""")
    gr_ascii_btn = gr.Button("Data measures")
    non_ascii_hist = gr.Plot()
    gr_ascii_btn.click(non_ascii_check, inputs=[texts_df, text_column], outputs=[non_ascii_hist])

    gr.Markdown("## Explore toxicity")
    checkbox = gr.Checkbox(value=False, label="Run on full first parquet data (better not)")
    gr_toxicity_btn = gr.Button("Run Perspective API to check toxicity of random samples.")
    toxicity_progress_bar = gr.Label(show_label=False)
    toxicity_hist = gr.Plot()
    with gr.Accordion("Explore examples with toxicity scores:", open=False):
        toxicity_df = gr.DataFrame()
    gr_toxicity_btn.click(
        call_perspective_api,
        inputs=[texts_df, text_column, checkbox],
        outputs=[toxicity_progress_bar, toxicity_hist, toxicity_df]
    )
demo.launch()