import gradio as gr
import pandas as pd
import polars as pl
import spaces
import torch
from torch import nn
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub import PyTorchModelHubMixin
from transformers import AutoConfig, AutoModel, AutoTokenizer
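
# Classifier head on top of a pretrained encoder; PyTorchModelHubMixin makes the
# whole model (backbone + linear head) loadable directly from the Hugging Face Hub.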
class QualityModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config):
        super().__init__()
        self.model = AutoModel.from_pretrained(config["base_model"])
        self.dropout = nn.Dropout(config["fc_dropout"])
        self.fc = nn.Linear(self.model.config.hidden_size, len(config["id2label"]))

    def forward(self, input_ids, attention_mask):
        features = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        dropped = self.dropout(features)
        outputs = self.fc(dropped)
        # Classify from the first ([CLS]) token's logits.
        return torch.softmax(outputs[:, 0, :], dim=1)
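
# Load the quality classifier once at startup (labels: Low / Medium / High).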
device = "cuda" if torch.cuda.is_available() else "cpu"
config = AutoConfig.from_pretrained("nvidia/quality-classifier-deberta")
tokenizer = AutoTokenizer.from_pretrained("nvidia/quality-classifier-deberta")
model = QualityModel.from_pretrained("nvidia/quality-classifier-deberta").to(device)
model.eval()
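
# Classify a batch of texts and map predicted class ids back to label names.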
@spaces.GPU
def predict(texts: list[str]):
    inputs = tokenizer(
        texts, return_tensors="pt", padding="longest", truncation=True
    ).to(device)
    outputs = model(inputs["input_ids"], inputs["attention_mask"])
    predicted_classes = torch.argmax(outputs, dim=1)
    predicted_labels = [
        config.id2label[class_idx.item()] for class_idx in predicted_classes.cpu().numpy()
    ]
    return predicted_labels
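
# Build the bar chart of label counts plus up-to-20-row sample tables per class.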
def plot_and_df(texts, preds):
    texts_df = pd.DataFrame({"quality": preds, "text": texts})
    counts = pd.DataFrame({"quality": preds}).value_counts().to_frame()
    counts.reset_index(inplace=True)
    return (
        gr.BarPlot(counts, x="quality", y="count"),
        texts_df[texts_df["quality"] == "Low"][:20],
        texts_df[texts_df["quality"] == "Medium"][:20],
        texts_df[texts_df["quality"] == "High"][:20],
    )
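
# Read the first parquet shard of the dataset's auto-converted "default" config,
# classify it in batches, and yield incremental plot/table updates to the UI.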
def run_quality_check(dataset, column, batch_size, num_examples):
    config_name = "default"
    data = pl.read_parquet(
        f"hf://datasets/{dataset}@~parquet/{config_name}/train/0000.parquet",
        columns=[column],
    )
    texts = data[column].to_list()
    predictions, texts_processed = [], []
    for i in range(0, min(len(texts), num_examples), batch_size):
        batch_texts = texts[i:i + batch_size]
        batch_predictions = predict(batch_texts)
        predictions.extend(batch_predictions)
        texts_processed.extend(batch_texts)
        yield plot_and_df(texts_processed, predictions)
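
# UI: dataset search box, embedded dataset viewer, and the quality-check controls.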
with gr.Blocks() as demo:
    gr.Markdown("# 💫 Dataset Quality Checker 💫")
    dataset_name = HuggingfaceHubSearch(
        label="Hub Dataset ID",
        placeholder="Search for a dataset id on the Hugging Face Hub",
        search_type="dataset",
        value="fka/awesome-chatgpt-prompts",
    )
    # config_name = "default"  # TODO: user input

    @gr.render(inputs=dataset_name)
    def embed(name):
        html_code = f"""
        <iframe
            src="https://huggingface.co/datasets/{name}/embed/viewer/default/train"
            frameborder="0"
            width="100%"
            height="700px"
        ></iframe>
        """
        return gr.HTML(value=html_code)

    text_column = gr.Textbox(
        placeholder="text",
        label="Text column name to check (data must be non-nested, raw texts!)",
    )
    # precision=0 so the values arrive as ints (batch_size is used as a range() step)
    batch_size = gr.Number(100, label="Batch size", precision=0)
    num_examples = gr.Number(1000, label="Num examples to check", precision=0)
    gr_check_btn = gr.Button("Check Dataset")
    plot = gr.BarPlot()
    with gr.Accordion("Explore some individual examples for each class", open=False):
        df_low, df_medium, df_high = gr.DataFrame(), gr.DataFrame(), gr.DataFrame()
    gr_check_btn.click(
        run_quality_check,
        inputs=[dataset_name, text_column, batch_size, num_examples],
        outputs=[plot, df_low, df_medium, df_high],
    )

demo.launch()