pratyushmaini committed on
Commit
3b4591b
·
1 Parent(s): b1029e7
Files changed (1) hide show
  1. app.py +46 -15
app.py CHANGED
@@ -1,18 +1,49 @@
 
1
  import gradio as gr
2
  from datasets import load_dataset
3
 
4
- def view_dataset(sample_size):
5
- dataset = load_dataset("locuslab/TOFU")
6
- data = dataset['train'].select(range(sample_size)).to_pandas()
7
- return data
8
-
9
- interface = gr.Interface(
10
- fn=view_dataset,
11
- inputs=gr.inputs.Slider(minimum=1, maximum=100, step=1, default=5, label="Number of Samples"),
12
- outputs="dataframe",
13
- title="TOFU Dataset Viewer",
14
- description="Interactive viewer for the TOFU dataset"
15
- )
16
-
17
- if __name__ == "__main__":
18
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
  import gradio as gr
3
  from datasets import load_dataset
4
 
5
# Load the TOFU dataset and keep a handle on the "full" split that the
# viewer pages through.
tofu_ds = load_dataset("locuslab/TOFU")
ds = tofu_ds["full"]

# Convert the split to a pandas DataFrame once at start-up; every callback
# below slices this module-level frame.
df = ds.to_pandas()

# Reindexing on the frame's own column list leaves the column order as it
# is; `cols` is the place to edit if a different order is ever wanted.
cols = df.columns.tolist()
df = df.reindex(columns=cols)

# Number of rows shown per page.
LINES_NUMBER = 20
13
+
14
def display_df():
    """Return the first ``LINES_NUMBER`` rows of the module-level ``df``.

    ``df`` is read at call time, so once ``display_next`` has reassigned it
    to a shuffled copy this returns the head of that shuffled frame.
    """
    return df.head(LINES_NUMBER)
17
+
18
def display_next(dataframe, end):
    """Return the next page of ``LINES_NUMBER`` rows and the new end offset.

    Args:
        dataframe: The rows currently displayed. Only its length is used,
            as the starting offset when ``end`` is empty.
        end: Row offset at which the previous page stopped. ``None`` or ``0``
            (e.g. the hidden counter has never been set) means "continue
            right after the rows currently displayed".

    Returns:
        A ``(page, end)`` tuple: the slice ``df.iloc[start:end]`` and the
        offset to pass back in on the next call.

    Side effects:
        When the next full page would run past the end of ``df``, the
        module-level ``df`` is replaced by a shuffled copy, ``"Shuffle"`` is
        printed, and paging restarts from row 0. Only full pages are served,
        so a trailing partial page is skipped before the reshuffle.
        NOTE(review): ``df`` is shared module state, so a reshuffle affects
        every caller of this module - confirm that is acceptable.
    """
    global df

    page_size = int(LINES_NUMBER)

    # Clamp to 0 so a negative offset cannot produce a slice that wraps
    # around from the end of the frame.
    start = max(int(end or len(dataframe)), 0)
    end = start + page_size

    # Wrap only when the page would actually overrun the data. A page that
    # ends exactly on the last row is still a complete, valid page.
    if end > len(df):
        start, end = 0, page_size
        df = df.sample(frac=1)
        print("Shuffle")

    # If the whole dataset is shorter than one page this is simply a short
    # page rather than an error.
    page = df.iloc[start:end]
    return page, end
30
+
31
# First page, computed once so the table is populated as soon as the UI loads.
initial_dataframe = display_df()

# Gradio Blocks
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>TOFU Dataset Viewer</center></h1>")

    # Controls: a hidden numeric field carrying the paging offset between
    # clicks, plus the two navigation buttons.
    with gr.Row():
        num_end = gr.Number(visible=False)
        b1 = gr.Button("Get Initial dataframe")
        b2 = gr.Button("Next Rows")

    # Read-only table showing the current page.
    with gr.Row():
        out_dataframe = gr.Dataframe(
            initial_dataframe,
            wrap=True,
            interactive=False,
            datatype=['str', 'str'],
        )

    # "Get Initial dataframe" reloads the head of the frame into the table.
    b1.click(
        fn=display_df,
        outputs=out_dataframe,
        api_name="initial_dataframe",
    )
    # "Next Rows" feeds the displayed table and the stored offset to
    # display_next and writes both of its results back.
    b2.click(
        fn=display_next,
        inputs=[out_dataframe, num_end],
        outputs=[out_dataframe, num_end],
        api_name="next_rows",
    )

demo.launch(debug=True, show_error=True)