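"""Dataset Quality Checker: a Gradio Space that scores a Hub dataset's text
column with nvidia/quality-classifier-deberta and streams Low/Medium/High
counts plus example rows to the UI, batch by batch."""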
import gradio as gr
import polars as pl
from gradio_huggingfacehub_search import HuggingfaceHubSearch
import torch
import spaces
from torch import nn
from transformers import AutoModel, AutoTokenizer, AutoConfig
from huggingface_hub import PyTorchModelHubMixin
import pandas as pd


class QualityModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config):
        super().__init__()
        # DeBERTa backbone followed by a dropout + linear classification head.
        self.model = AutoModel.from_pretrained(config["base_model"])
        self.dropout = nn.Dropout(config["fc_dropout"])
        self.fc = nn.Linear(self.model.config.hidden_size, len(config["id2label"]))

    def forward(self, input_ids, attention_mask):
        features = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        dropped = self.dropout(features)
        outputs = self.fc(dropped)
        # Classify from the [CLS] token (position 0); return per-class probabilities.
        return torch.softmax(outputs[:, 0, :], dim=1)

# Load the classifier, tokenizer, and label config once at startup.
# id2label maps class ids to the "Low"/"Medium"/"High" quality labels.
device = "cuda" if torch.cuda.is_available() else "cpu"
config = AutoConfig.from_pretrained("nvidia/quality-classifier-deberta")
tokenizer = AutoTokenizer.from_pretrained("nvidia/quality-classifier-deberta")
model = QualityModel.from_pretrained("nvidia/quality-classifier-deberta").to(device)
model.eval()


# On ZeroGPU Spaces, spaces.GPU requests a GPU only for the duration of each call.
@spaces.GPU
def predict(texts: list[str]) -> list[str]:
    inputs = tokenizer(
        texts, return_tensors="pt", padding="longest", truncation=True
    ).to(device)
    # Inference only: skip gradient tracking.
    with torch.inference_mode():
        outputs = model(inputs["input_ids"], inputs["attention_mask"])
    predicted_classes = torch.argmax(outputs, dim=1)
    return [config.id2label[idx] for idx in predicted_classes.cpu().tolist()]


def plot_and_df(texts, preds):
    texts_df = pd.DataFrame({"quality": preds, "text": texts})
    # Per-class counts for the bar plot.
    counts = pd.DataFrame({"quality": preds}).value_counts().to_frame()
    counts.reset_index(inplace=True)
    return (
        gr.BarPlot(counts, x="quality", y="count"),
        texts_df[texts_df["quality"] == "Low"][:20],
        texts_df[texts_df["quality"] == "Medium"][:20],
        texts_df[texts_df["quality"] == "High"][:20],
    )


def run_quality_check(dataset, column, batch_size, num_examples):
    ds_config = "default"  # dataset config name (distinct from the model AutoConfig above)
    # Read the Hub's auto-converted parquet export of the dataset.
    data = pl.read_parquet(
        f"hf://datasets/{dataset}@~parquet/{ds_config}/train/0000.parquet",
        columns=[column],
    )
    texts = data[column].to_list()
    # Gradio Number inputs may arrive as floats; coerce for range() and slicing.
    batch_size, num_examples = int(batch_size), int(num_examples)
    predictions, texts_processed = [], []
    for i in range(0, min(len(texts), num_examples), batch_size):
        batch_texts = texts[i : min(i + batch_size, num_examples)]
        predictions.extend(predict(batch_texts))
        texts_processed.extend(batch_texts)
        # Yield after every batch so the UI updates progressively.
        yield plot_and_df(texts_processed, predictions)


with gr.Blocks() as demo:
    gr.Markdown("# 💫 Dataset Quality Checker 💫")
    dataset_name = HuggingfaceHubSearch(
        label="Hub Dataset ID",
        placeholder="Search for a dataset id on the Hugging Face Hub",
        search_type="dataset",
        value="fka/awesome-chatgpt-prompts",
    )
    # config_name = "default"  # TODO: user input
    @gr.render(inputs=dataset_name)
    def embed(name):
        # Re-render the dataset viewer iframe whenever the selected dataset changes.
        html_code = f"""
        <iframe
          src="https://huggingface.co/datasets/{name}/embed/viewer/default/train"
          frameborder="0"
          width="100%"
          height="700px"
        ></iframe>
        """
        return gr.HTML(value=html_code)
    text_column = gr.Textbox(placeholder="text", label="Text column name to check (data must be non-nested, raw texts!)")
    batch_size = gr.Number(100, precision=0, label="Batch size")
    num_examples = gr.Number(1000, precision=0, label="Num examples to check")
    gr_check_btn = gr.Button("Check Dataset")
    plot = gr.BarPlot()

    with gr.Accordion("Explore some individual examples for each class", open=False):
        df_low, df_medium, df_high = gr.DataFrame(), gr.DataFrame(), gr.DataFrame()
    gr_check_btn.click(
        run_quality_check,
        inputs=[dataset_name, text_column, batch_size, num_examples],
        outputs=[plot, df_low, df_medium, df_high],
    )

demo.launch()
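
# A minimal local smoke test, run in place of demo.launch() above (a sketch:
# it assumes Hub access and that the default dataset has a "prompt" column):
#
# for _plot, low_df, _medium_df, _high_df in run_quality_check(
#     "fka/awesome-chatgpt-prompts", "prompt", batch_size=32, num_examples=64
# ):
#     print(f"{len(low_df)} low-quality examples so far")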