polinaeterna committed on
Commit
44cbba4
·
1 Parent(s): 7badbdb

get config and split with api, include partial datasets

Browse files
Files changed (1) hide show
  1. app.py +31 -8
app.py CHANGED
@@ -1,13 +1,22 @@
 
 
 
 
1
  import gradio as gr
 
2
  import polars as pl
 
3
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
 
4
  import torch
5
- import spaces
6
  from torch import nn
7
  from transformers import AutoModel, AutoTokenizer, AutoConfig
8
- from huggingface_hub import PyTorchModelHubMixin
9
- import pandas as pd
10
- from collections import Counter
 
 
 
11
 
12
 
13
  class QualityModel(nn.Module, PyTorchModelHubMixin):
@@ -64,8 +73,22 @@ def plot_and_df(texts, preds):
64
 
65
 
66
  def run_quality_check(dataset, column, batch_size, num_examples):
67
- config = "default"
68
- data = pl.read_parquet(f"hf://datasets/{dataset}@~parquet/{config}/train/0000.parquet", columns=[column])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  texts = data[column].to_list()
70
  # batch_size = 100
71
  predictions, texts_processed = [], []
@@ -106,8 +129,8 @@ with gr.Blocks() as demo:
106
  return gr.HTML(value=html_code)
107
 
108
  text_column = gr.Textbox(placeholder="text", label="Text colum name to check (data must be non-nested, raw texts!)")
109
- batch_size = gr.Slider(0, 128, 64, step=8, label="Inference batch size (set this to smaller value if this space crashes.)")
110
- num_examples = gr.Number(1000, label="Number of first examples to check")
111
  gr_check_btn = gr.Button("Check Dataset")
112
  progress_bar = gr.Label(show_label=False)
113
  plot = gr.BarPlot()
 
1
+ import requests
2
+ from collections import Counter
3
+ from requests.adapters import HTTPAdapter, Retry
4
+
5
  import gradio as gr
6
+ import pandas as pd
7
  import polars as pl
8
+ import spaces
9
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
10
+ from huggingface_hub import PyTorchModelHubMixin
11
  import torch
 
12
  from torch import nn
13
  from transformers import AutoModel, AutoTokenizer, AutoConfig
14
+
15
+
16
+
17
+ session = requests.Session()
18
+ retries = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])
19
+ session.mount('http://', HTTPAdapter(max_retries=retries))
20
 
21
 
22
  class QualityModel(nn.Module, PyTorchModelHubMixin):
 
73
 
74
 
75
  def run_quality_check(dataset, column, batch_size, num_examples):
76
+ # config = "default"
77
+ info_resp = session.get(f"https://datasets-server.huggingface.co/info?dataset={dataset}", timeout=3).json()
78
+ if "error" in info_resp:
79
+ yield "❌ " + info_resp["error"], gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
80
+ return
81
+ config = "default" if "default" in info_resp["dataset_info"] else next(iter(info_resp["dataset_info"]))
82
+ split = "train" if "train" in info_resp["dataset_info"][config]["splits"] else next(
83
+ iter(info_resp["dataset_info"][config]["splits"]))
84
+ try:
85
+ data = pl.read_parquet(f"hf://datasets/{dataset}@~parquet/{config}/{split}/0000.parquet", columns=[column])
86
+ except pl.exceptions.ComputeError:
87
+ try:
88
+ data = pl.read_parquet(f"hf://datasets/{dataset}@~parquet/{config}/partial-{split}/0000.parquet", columns=[column])
89
+ except Exception as error:
90
+ yield f"❌ {error}", gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
91
+ return
92
  texts = data[column].to_list()
93
  # batch_size = 100
94
  predictions, texts_processed = [], []
 
129
  return gr.HTML(value=html_code)
130
 
131
  text_column = gr.Textbox(placeholder="text", label="Text colum name to check (data must be non-nested, raw texts!)")
132
+ batch_size = gr.Slider(0, 128, 32, step=8, label="Inference batch size (set this to smaller value if this space crashes.)")
133
+ num_examples = gr.Number(500, label="Number of first examples to check")
134
  gr_check_btn = gr.Button("Check Dataset")
135
  progress_bar = gr.Label(show_label=False)
136
  plot = gr.BarPlot()