osanseviero HF staff commited on
Commit
578bab4
1 Parent(s): ebc9782

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+
4
+ import datasets
5
+ from datasets import load_dataset
6
+ from huggingface_hub import delete_repo
7
+
8
+ idx = 0
9
+ data_to_label = load_dataset("active-learning/to_label_samples")
10
+ imgs = data_to_label["train"]["image"]
11
+
12
+ def get_image():
13
+ global idx
14
+ new_img = imgs[idx]
15
+ idx += 1
16
+ return new_img
17
+
18
+ labeled_data = []
19
+
20
+ information = """# Active Learning Demo
21
+
22
+ This demo showcases Active Learning, which is great when labeling is expensive. In this demo, you will label images by choosing a digit (0-9).
23
+
24
+ How does this work?
25
+ * There is a large pool of unlabeled images
26
+ * A model is trained with the few labeled images
27
+ * We can then use the model to pick the images with the lowest confidence or with the lowest probability of corresponding to an image. These are the images for which the model is confused, so by improving them, the quality of the model can improve much more than queries for which the model was already doing well!
28
+ * In this UI, you will be provided a couple of images to label
29
+ * Once all the provided images are labeled, the model is retrained, and a new set of images is chosen!
30
+ """
31
+
32
+ webhook_info = """## Model Retraining
33
+
34
+ There are new labeled images. The model is retraining. Follow progress in [here](https://huggingface.co/spaces/active-learning/webhook).
35
+ """
36
+
37
+ with gr.Blocks() as demo:
38
+ gr.Markdown(information)
39
+
40
+ img_to_label = gr.Image(shape=[28,28], value=get_image())
41
+ label_dropdown = gr.Dropdown(choices=[0,1,2,3,4,5,6,7,8,9], interactive=True, value=0)
42
+ save_btn = gr.Button("Save label")
43
+ output_box = gr.Markdown(value=webhook_info, visible=False)
44
+ reload_btn = gr.Button("Reload", visible=False)
45
+
46
+ def save_data(img, label):
47
+ global labeled_data
48
+ global idx
49
+
50
+ labeled_data.append([img, label])
51
+
52
+ if len(imgs) == idx :
53
+ # Remove dataset of queries to label
54
+ # datasets library does not allow pushing an empty dataset, so as a
55
+ # workaround we just delete the repo
56
+ delete_repo(repo_id="active-learning/to_label_samples", repo_type="dataset")
57
+
58
+ # Save to dataset
59
+ labeled_dataset = load_dataset("active-learning/labeled_samples")["train"]
60
+ feature = datasets.Image(decode=False)
61
+ for img, label in labeled_data:
62
+ # Hack due to https://github.com/huggingface/datasets/issues/4796
63
+ labeled_dataset = labeled_dataset.add_item({
64
+ "image": feature.encode_example(Image.fromarray(img)),
65
+ "label": label
66
+ })
67
+ labeled_dataset.push_to_hub("active-learning/labeled_samples")
68
+ labeled_data = []
69
+ idx = 0
70
+ return {
71
+ img_to_label: gr.update(visible=False),
72
+ label_dropdown: gr.update(visible=False),
73
+ save_btn: gr.update(visible=False),
74
+ output_box: gr.update(visible=True),
75
+ reload_btn: gr.update(visible=True)
76
+ }
77
+ else:
78
+ return {
79
+ img_to_label: gr.update(value=get_image())
80
+ }
81
+
82
+ def reload_data():
83
+ global data_to_label
84
+ global imgs
85
+ data_to_label = load_dataset("active-learning/to_label_samples")
86
+ imgs = data_to_label["train"]["image"]
87
+ if len(imgs) == 0:
88
+ return
89
+ else:
90
+ return {
91
+ img_to_label: gr.update(visible=True),
92
+ label_dropdown: gr.update(visible=True),
93
+ save_btn: gr.update(visible=True),
94
+ output_box: gr.update(visible=False),
95
+ reload_btn: gr.update(visible=False)
96
+ }
97
+
98
+ save_btn.click(
99
+ save_data,
100
+ inputs=[img_to_label, label_dropdown],
101
+ outputs=[img_to_label, label_dropdown, save_btn, output_box, reload_btn]
102
+ )
103
+
104
+ reload_btn.click(
105
+ reload_data,
106
+ outputs=[img_to_label, label_dropdown, save_btn, output_box, reload_btn]
107
+ )
108
+
109
+ demo.launch(debug=True)