import gradio as gr
import pandas as pd
import polars as pl
import torch
import spaces
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer


class QualityModel(nn.Module, PyTorchModelHubMixin):
    """DeBERTa backbone with a linear classification head over the [CLS] token."""

    def __init__(self, config):
        super().__init__()
        self.model = AutoModel.from_pretrained(config["base_model"])
        self.dropout = nn.Dropout(config["fc_dropout"])
        self.fc = nn.Linear(self.model.config.hidden_size, len(config["id2label"]))

    def forward(self, input_ids, attention_mask):
        features = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        dropped = self.dropout(features)
        outputs = self.fc(dropped)
        # Classify from the first ([CLS]) token and return class probabilities.
        return torch.softmax(outputs[:, 0, :], dim=1)


device = "cuda" if torch.cuda.is_available() else "cpu"
config = AutoConfig.from_pretrained("nvidia/quality-classifier-deberta")
tokenizer = AutoTokenizer.from_pretrained("nvidia/quality-classifier-deberta")
model = QualityModel.from_pretrained("nvidia/quality-classifier-deberta").to(device)
model.eval()


@spaces.GPU
def predict(texts: list[str]) -> list[str]:
    """Classify a batch of texts into quality labels ("High"/"Medium"/"Low")."""
    inputs = tokenizer(
        texts, return_tensors="pt", padding="longest", truncation=True
    ).to(device)
    with torch.inference_mode():  # no gradients needed at inference time
        outputs = model(inputs["input_ids"], inputs["attention_mask"])
    predicted_classes = torch.argmax(outputs, dim=1)
    predicted_labels = [
        config.id2label[class_idx.item()]
        for class_idx in predicted_classes.cpu().numpy()
    ]
    return predicted_labels


def plot_and_df(texts, preds):
    texts_df = pd.DataFrame({"quality": preds, "text": texts})
    # Aggregate label frequencies into a two-column frame for the bar plot.
    counts = (
        texts_df["quality"]
        .value_counts()
        .rename_axis("quality")
        .reset_index(name="count")
    )
    return (
        gr.BarPlot(counts, x="quality", y="count"),
        texts_df[texts_df["quality"] == "Low"].head(20),
        texts_df[texts_df["quality"] == "Medium"].head(20),
        texts_df[texts_df["quality"] == "High"].head(20),
    )


def run_quality_check(dataset, column, batch_size, num_examples):
    config_name = "default"
    # gr.Number delivers floats by default, but range() needs ints.
    batch_size, num_examples = int(batch_size), int(num_examples)
    # Read the first auto-converted Parquet shard of the train split
    # ("@~parquet" resolves to the refs/convert/parquet revision).
    data = pl.read_parquet(
        f"hf://datasets/{dataset}@~parquet/{config_name}/train/0000.parquet",
        columns=[column],
    )
    texts = data[column].to_list()
    predictions, texts_processed = [], []
    for i in range(0, min(len(texts), num_examples), batch_size):
        batch_texts = texts[i : i + batch_size]
        predictions.extend(predict(batch_texts))
        texts_processed.extend(batch_texts)
        # Stream intermediate results so the UI updates after every batch.
        yield plot_and_df(texts_processed, predictions)


with gr.Blocks() as demo:
    gr.Markdown("# 💫 Dataset Quality Checker 💫")
    dataset_name = HuggingfaceHubSearch(
        label="Hub Dataset ID",
        placeholder="Search for a dataset id on Hugging Face",
        search_type="dataset",
        value="fka/awesome-chatgpt-prompts",
    )
    # config_name = "default"  # TODO: user input

    @gr.render(inputs=dataset_name)
    def embed(name):
        # The embed HTML was empty in the original; the iframe below is an
        # assumed dataset-viewer embed for the selected dataset.
        html_code = f"""
        <iframe
            src="https://huggingface.co/datasets/{name}/embed/viewer/default/train"
            frameborder="0"
            width="100%"
            height="560px"
        ></iframe>
        """
        gr.HTML(value=html_code)

    text_column = gr.Textbox(
        placeholder="text",
        label="Text column name to check (data must be non-nested, raw texts!)",
    )
    batch_size = gr.Number(100, label="Batch size")
    num_examples = gr.Number(1000, label="Num examples to check")
    gr_check_btn = gr.Button("Check Dataset")
    plot = gr.BarPlot()
    with gr.Accordion("Explore some individual examples for each class", open=False):
        df_low, df_medium, df_high = gr.DataFrame(), gr.DataFrame(), gr.DataFrame()
    gr_check_btn.click(
        run_quality_check,
        inputs=[dataset_name, text_column, batch_size, num_examples],
        outputs=[plot, df_low, df_medium, df_high],
    )

demo.launch()
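
# --- Added usage sketch (not part of the original app) ---
# A minimal way to exercise the classifier without the Gradio UI, assuming the
# model and tokenizer above loaded successfully. The sample strings are
# invented for illustration. Kept commented out so it never runs alongside
# demo.launch():
#
#   samples = [
#       "Photosynthesis converts light energy into chemical energy in plants.",
#       "CLICK here!!! free $$$ now",
#   ]
#   print(predict(samples))  # one "High"/"Medium"/"Low" label per input text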