|
import argilla as rg |
|
import gradio as gr |
|
from plotly.subplots import make_subplots |
|
import plotly.graph_objects as go |
|
import os |
|
from typing import List |
|
from collections import defaultdict |
|
|
|
# Shared Argilla connection used by every callback in this app.
# The server URL and API key come from the ARGILLA_FW2_URL / ARGILLA_FW2_KEY
# environment variables; an unset variable is passed through as None.
client = rg.Argilla(
    api_url=os.environ.get("ARGILLA_FW2_URL"),
    api_key=os.environ.get("ARGILLA_FW2_KEY"),
)
|
|
|
def get_stats(dataset_idx, username):
    """Tally "educational_value" responses for one user and for everyone.

    Reads the module-level ``datasets`` list (built when the UI is created)
    and the module-level Argilla ``client``.

    Args:
        dataset_idx: Position of the selected dataset in ``datasets``; this is
            the value emitted by the dataset dropdown.
        username: Argilla username whose responses are tallied separately.

    Returns:
        ``(user_label_counts, all_label_counts)``: two ``defaultdict(int)``
        objects mapping each response value to its number of occurrences.
        ``all_label_counts`` also includes the user's own responses.

    Raises:
        gr.Error: if no dataset is selected, the username is blank, or the
            username is not known to the Argilla server. The message is then
            shown in the Gradio UI instead of an opaque stack trace.
    """
    if dataset_idx is None:
        raise gr.Error("Please select a dataset.")

    # Textbox input may be None or carry stray whitespace from copy/paste.
    username = (username or "").strip()
    if not username:
        raise gr.Error("Please enter your username.")

    dataset = datasets[dataset_idx]

    user = client.users(username)
    # NOTE(review): client.users() is expected to return None for an unknown
    # username (calling `.id` on it would raise AttributeError) — confirm for
    # the pinned argilla version.
    if user is None:
        raise gr.Error(f"User '{username}' was not found.")
    user_id = user.id

    user_label_counts = defaultdict(int)
    all_label_counts = defaultdict(int)

    for record in dataset.records:
        for response in record.responses["educational_value"]:
            label = response.value
            # Every response counts toward the team total; the selected
            # user's responses are additionally tallied on their own.
            all_label_counts[label] += 1
            if response.user_id == user_id:
                user_label_counts[label] += 1

    return user_label_counts, all_label_counts
|
|
|
|
|
def build_plot(user_label_counts, all_label_counts, labels=None):
    """Build side-by-side pie charts of a user's labels vs. the team's.

    Args:
        user_label_counts: Mapping of label -> count for the selected user.
        all_label_counts: Mapping of label -> count across all annotators.
        labels: Optional ordered list of labels to show first. Defaults to
            the "educational_value" labels this dashboard was written for.
            Any label present in either mapping but missing from this list
            is appended (sorted by its string form), so its counts are not
            silently dropped from the charts and percentages.

    Returns:
        A plotly ``Figure`` with two pie subplots ("My Label Usage" and
        "Team Label Usage").
    """
    if labels is None:
        # NOTE(review): these strings must match the stored response values
        # exactly; the last one looks mis-encoded — confirm against the dataset.
        labels = ['None', 'Minimal', 'Basic', 'Good', 'Excellent', 'β Problematic Content β']

    known = set(labels)
    extras = {
        label
        for counts in (user_label_counts, all_label_counts)
        for label in counts
        if label not in known
    }
    labels = list(labels) + sorted(extras, key=str)

    # .get() rather than [] so plain dicts work and defaultdict inputs are
    # not mutated by inserting zero-count keys.
    user_counts = [user_label_counts.get(label, 0) for label in labels]
    overall_counts = [all_label_counts.get(label, 0) for label in labels]

    fig = make_subplots(
        rows=1,
        cols=2,
        # Pie traces must live in 'domain'-type subplots.
        specs=[[{'type': 'domain'}, {'type': 'domain'}]],
        subplot_titles=['My Label Usage', 'Team Label Usage'],
    )
    fig.add_trace(go.Pie(labels=labels, values=user_counts, name="User Label Counts"), 1, 1)
    fig.add_trace(go.Pie(labels=labels, values=overall_counts, name="Overall Label Counts"), 1, 2)

    fig.update_layout(title="User vs Overall Label Counts")

    return fig
|
|
|
|
|
def update_dashboard(dataset_idx, username):
    """Gradio callback: recompute label statistics and return the figure.

    Args:
        dataset_idx: Value of the dataset dropdown (index into ``datasets``).
        username: Text typed into the username box.

    Returns:
        The plotly figure produced by ``build_plot``.
    """
    counts_for_user, counts_for_team = get_stats(dataset_idx, username)
    figure = build_plot(counts_for_user, counts_for_team)
    return figure
|
|
|
|
|
|
|
# Dashboard layout: pick a dataset, enter a username, and compare that
# user's label distribution with the whole team's.
with gr.Blocks() as demo:
    gr.Markdown("# How do my annotations compare to my team's?")

    with gr.Row():
        # Listed once, when the UI is built. get_stats() indexes this
        # module-level list with the dropdown's value.
        datasets = client.datasets.list()
        # (display name, index) pairs: the dropdown shows the name and
        # passes the index to the callback.
        dataset_choices = [(dataset.name, idx) for idx, dataset in enumerate(datasets)]
        datasets_dropdown = gr.Dropdown(
            choices=dataset_choices,
            label="Select your dataset",
            # Preselect the first dataset only when one exists; 0 is not a
            # valid choice for an empty dropdown.
            value=0 if dataset_choices else None,
            visible=True,
        )

        search_box = gr.Textbox(type="text", label="Enter your username:")

    with gr.Row():
        # NOTE(review): "π" looks like a mis-encoded emoji — confirm the
        # intended button text.
        search_button = gr.Button("Search π")

    with gr.Row():
        plot_output = gr.Plot(label="Team and user annotations")

    search_button.click(
        fn=update_dashboard,
        inputs=[datasets_dropdown, search_box],
        outputs=[plot_output],
    )


# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()
|
|
|
|