# Hugging Face Space: Dataset Quality Checker (runs on ZeroGPU)
import gradio as gr
import pandas as pd
import polars as pl
import spaces
import torch
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer
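
# Classification head on top of a pretrained transformer encoder. Inheriting from
# PyTorchModelHubMixin lets the full model be loaded from the Hub via from_pretrained.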
class QualityModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config):
        super().__init__()
        self.model = AutoModel.from_pretrained(config["base_model"])
        self.dropout = nn.Dropout(config["fc_dropout"])
        self.fc = nn.Linear(self.model.config.hidden_size, len(config["id2label"]))

    def forward(self, input_ids, attention_mask):
        features = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        dropped = self.dropout(features)
        outputs = self.fc(dropped)
        # Classify from the first ([CLS]) token and return per-class probabilities.
        return torch.softmax(outputs[:, 0, :], dim=1)
device = "cuda" if torch.cuda.is_available() else "cpu"
config = AutoConfig.from_pretrained("nvidia/quality-classifier-deberta")
tokenizer = AutoTokenizer.from_pretrained("nvidia/quality-classifier-deberta")
model = QualityModel.from_pretrained("nvidia/quality-classifier-deberta").to(device)
model.eval()
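
# On ZeroGPU Spaces a GPU is attached only while a function decorated with
# @spaces.GPU runs; the decorator below is assumed from the `spaces` import
# and the Space's "Running on Zero" status.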
@spaces.GPU
def predict(texts: list[str]) -> list[str]:
    inputs = tokenizer(
        texts, return_tensors="pt", padding="longest", truncation=True
    ).to(device)
    with torch.inference_mode():  # no gradients needed at inference time
        outputs = model(inputs["input_ids"], inputs["attention_mask"])
    predicted_classes = torch.argmax(outputs, dim=1)
    # Map predicted class indices to their label names ("Low", "Medium", "High").
    return [config.id2label[class_idx] for class_idx in predicted_classes.cpu().tolist()]
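
# Build the outputs for the UI: a bar chart of the quality distribution plus
# one table of example rows per quality class.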
def plot_and_df(texts, preds):
    texts_df = pd.DataFrame({"quality": preds, "text": texts})
    counts = texts_df["quality"].value_counts().to_frame().reset_index()
    return (
        gr.BarPlot(counts, x="quality", y="count"),
        # Show at most 20 examples per quality class.
        texts_df[texts_df["quality"] == "Low"][:20],
        texts_df[texts_df["quality"] == "Medium"][:20],
        texts_df[texts_df["quality"] == "High"][:20],
    )
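
# Generator wired to the "Check Dataset" button: loads the dataset's Parquet
# export, classifies up to `num_examples` texts in batches, and streams results.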
def run_quality_check(dataset, column, batch_size, num_examples):
    config_name = "default"  # TODO: let the user choose a dataset config
    # Read only the requested column from the first Parquet shard of the
    # dataset's auto-converted `refs/convert/parquet` branch.
    data = pl.read_parquet(
        f"hf://datasets/{dataset}@~parquet/{config_name}/train/0000.parquet",
        columns=[column],
    )
    texts = data[column].to_list()
    predictions, texts_processed = [], []
    for i in range(0, min(len(texts), num_examples), batch_size):
        batch_texts = texts[i : i + batch_size]
        predictions.extend(predict(batch_texts))
        texts_processed.extend(batch_texts)
        # Yield partial results so the UI updates after every batch.
        yield plot_and_df(texts_processed, predictions)
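
# Gradio UI: dataset search box, embedded dataset viewer, controls, and outputs.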
with gr.Blocks() as demo:
    gr.Markdown("# Dataset Quality Checker")
    dataset_name = HuggingfaceHubSearch(
        label="Hub Dataset ID",
        placeholder="Search for a dataset ID on the Hugging Face Hub",
        search_type="dataset",
        value="fka/awesome-chatgpt-prompts",
    )
    def embed(name):
        # Embed the Hub's dataset viewer to preview the selected dataset.
        html_code = f"""
        <iframe
            src="https://huggingface.co/datasets/{name}/embed/viewer/default/train"
            frameborder="0"
            width="100%"
            height="700px"
        ></iframe>
        """
        return gr.HTML(value=html_code)

    dataset_viewer = gr.HTML()
    # Assumed wiring (the original listing defines `embed` but never connects it):
    # refresh the embedded viewer whenever a new dataset is selected.
    dataset_name.change(embed, inputs=dataset_name, outputs=dataset_viewer)
    text_column = gr.Textbox(
        placeholder="text",
        label="Text column name to check (data must be non-nested, raw text!)",
    )
    # precision=0 makes gr.Number return ints, as required by range() above.
    batch_size = gr.Number(100, precision=0, label="Batch size")
    num_examples = gr.Number(1000, precision=0, label="Number of examples to check")
    gr_check_btn = gr.Button("Check Dataset")
    plot = gr.BarPlot()

    with gr.Accordion("Explore some individual examples for each class", open=False):
        df_low, df_medium, df_high = gr.DataFrame(), gr.DataFrame(), gr.DataFrame()

    gr_check_btn.click(
        run_quality_check,
        inputs=[dataset_name, text_column, batch_size, num_examples],
        outputs=[plot, df_low, df_medium, df_high],
    )

demo.launch()