File size: 2,688 Bytes
a852b26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import argilla as rg
import gradio as gr
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import os
from typing import List
from collections import defaultdict

# Shared Argilla client used by every callback below. Both connection settings
# are read from the environment (None if the variable is unset).
client = rg.Argilla(
    api_url=os.environ.get("ARGILLA_FW2_URL"),
    api_key=os.environ.get("ARGILLA_FW2_KEY"),
)

def get_stats(dataset_idx, username):
    """Tally "educational_value" labels for one annotator and for the whole team.

    Args:
        dataset_idx: Position of the chosen dataset in the module-level
            ``datasets`` list (the value emitted by the dataset dropdown).
        username: Argilla username whose responses are tallied separately.

    Returns:
        ``(user_label_counts, all_label_counts)``: two ``defaultdict(int)``
        objects mapping each response value to how often it was given. The
        team tally includes the user's own responses.

    Raises:
        gr.Error: If no dataset is selected, or the username is empty or not
            found on the Argilla server, so the problem is shown in the UI
            rather than surfacing as a traceback.
    """
    if dataset_idx is None:
        raise gr.Error("Please select a dataset.")
    dataset = datasets[dataset_idx]

    username = (username or "").strip()
    if not username:
        raise gr.Error("Please enter your username.")
    # client.users(...) returns None when no user has this name (Argilla SDK);
    # accessing .id on that would otherwise fail with an AttributeError.
    user = client.users(username)
    if user is None:
        raise gr.Error(f"Unknown Argilla username: {username!r}")
    user_id = user.id

    user_label_counts = defaultdict(int)
    all_label_counts = defaultdict(int)

    for record in dataset.records:
        for response in record.responses["educational_value"]:
            label = response.value
            # Every response counts toward the team total...
            all_label_counts[label] += 1
            # ...and the selected user's own responses are also tallied apart.
            if response.user_id == user_id:
                user_label_counts[label] += 1

    return user_label_counts, all_label_counts


def build_plot(user_label_counts, all_label_counts, labels=None):
    """Draw side-by-side pie charts comparing the user's label usage to the team's.

    Args:
        user_label_counts: Mapping of label -> count for the selected user.
        all_label_counts: Mapping of label -> count across all annotators.
        labels: Ordered labels to plot; pie slices follow this order. Defaults
            to the six ``educational_value`` options used by this dashboard.
            Labels missing from a mapping are plotted with a count of 0, and
            mapping keys not listed here are ignored.

    Returns:
        A ``plotly.graph_objects.Figure`` with one pie chart per subplot.
    """
    if labels is None:
        labels = ['None', 'Minimal', 'Basic', 'Good', 'Excellent', '❗ Problematic Content ❗']
    labels = list(labels)

    # .get() rather than [] so plain dicts are accepted and defaultdict inputs
    # are not mutated by inserting zero-count keys as a side effect.
    user_counts = [user_label_counts.get(label, 0) for label in labels]
    overall_counts = [all_label_counts.get(label, 0) for label in labels]

    fig = make_subplots(
        rows=1,
        cols=2,
        # Pie traces must live in 'domain'-type subplot cells.
        specs=[[{'type': 'domain'}, {'type': 'domain'}]],
        subplot_titles=['My Label Usage', 'Team Label Usage'],
    )

    fig.add_trace(go.Pie(labels=labels, values=user_counts, name="User Label Counts"), 1, 1)
    fig.add_trace(go.Pie(labels=labels, values=overall_counts, name="Overall Label Counts"), 1, 2)

    fig.update_layout(title="User vs Overall Label Counts")

    return fig


def update_dashboard(dataset_idx, username):
    """Gradio click handler: compute label tallies and return the comparison figure.

    Args:
        dataset_idx: Index of the dataset chosen in the dropdown.
        username: Argilla username typed into the search box.

    Returns:
        The figure built by ``build_plot`` from this user's and the team's tallies.
    """
    stats = get_stats(dataset_idx, username)
    # stats is (user_label_counts, all_label_counts), matching build_plot's order.
    return build_plot(*stats)



# Dashboard layout: pick a dataset, enter a username, and compare that user's
# label distribution with the whole team's.
with gr.Blocks() as demo:
    gr.Markdown("# How do my annotations compare to my team's?")

    with gr.Row():
        # Kept at module level on purpose: get_stats() indexes into `datasets`.
        datasets = client.datasets.list()
        # (display name, value) pairs: the dropdown shows names but emits the index.
        dataset_choices = [(dataset.name, idx) for idx, dataset in enumerate(datasets)]
        datasets_dropdown = gr.Dropdown(
            choices=dataset_choices,
            label="Select your dataset",
            # Preselect the first dataset; with no datasets, 0 would not be a
            # valid choice, so leave the dropdown empty instead.
            value=0 if dataset_choices else None,
            visible=True
        )

        search_box = gr.Textbox(type="text", label="Enter your username:")
    with gr.Row():
        search_button = gr.Button("Search 🔎")

    with gr.Row():
        plot_output = gr.Plot(label="Team and user annotations")

    search_button.click(
        fn=update_dashboard,
        inputs=[datasets_dropdown, search_box],
        outputs=[plot_output]
    )


if __name__ == "__main__":
    demo.launch()