|
|
|
import os |
|
from typing import List |
|
|
|
import argilla as rg |
|
import gradio as gr |
|
import pandas as pd |
|
import plotly.graph_objects as go |
|
|
|
# Shared Argilla client for the whole dashboard. The server URL and API key are
# read from the ARGILLA_API_URL / ARGILLA_API_KEY environment variables; an
# unset variable is passed through as None.
client = rg.Argilla(
    api_url=os.environ.get("ARGILLA_API_URL"),
    api_key=os.environ.get("ARGILLA_API_KEY"),
)
|
|
|
|
|
def _count_submitted(user_progress: dict) -> int:
    """Sum the ``"submitted"`` counts of a user's ``"completed"`` and ``"pending"`` buckets.

    A missing bucket, a missing ``"submitted"`` key or a ``None`` value counts
    as 0, so one empty bucket no longer turns the sum into ``None + int``.
    """
    return sum(
        (user_progress.get(bucket) or {}).get("submitted") or 0
        for bucket in ("completed", "pending")
    )


def get_progress(dataset: "rg.Dataset") -> dict:
    """Summarise the annotation progress of ``dataset``.

    Args:
        dataset: Argilla dataset; its ``progress(with_users_distribution=True)``
            result and ``settings.distribution.min_submitted`` are read.

    Returns:
        A dict with:
            - ``min_submitted``: ``settings.distribution.min_submitted`` as int.
            - ``total``: the ``"total"`` count reported by ``dataset.progress``.
            - ``annotated``: the ``"completed"`` count reported by ``dataset.progress``.
            - ``progress``: ``annotated / total`` as a percentage (0 when total is 0).
            - ``users``: username -> number of submitted responses, summed over
              the user's ``"completed"`` and ``"pending"`` buckets.
    """
    dataset_progress = dataset.progress(with_users_distribution=True)

    total, completed = dataset_progress["total"], dataset_progress["completed"]
    # An empty dataset would otherwise raise ZeroDivisionError.
    progress = (completed / total) * 100 if total > 0 else 0

    return {
        "min_submitted": int(dataset.settings.distribution.min_submitted),
        "total": total,
        "annotated": completed,
        "progress": progress,
        "users": {
            username: _count_submitted(user_progress)
            for username, user_progress in dataset_progress["users"].items()
        },
    }
|
|
|
|
|
def create_progress_bar(progress): |
|
top_labels = ["Completed", "Pending"] |
|
|
|
colors = ["rgba(38, 24, 74, 0.8)", "rgba(190, 192, 213, 1)"] |
|
|
|
n_labelled_records = sum(progress["users"].values()) |
|
|
|
n_target_records = 30000 |
|
x_data = [[n_labelled_records, n_target_records - n_labelled_records]] |
|
|
|
y_data = ["Progress"] |
|
|
|
fig = go.Figure() |
|
|
|
for i in range(0, len(x_data[0])): |
|
for xd, yd in zip(x_data, y_data): |
|
fig.add_trace( |
|
go.Bar( |
|
x=[xd[i]], |
|
y=[yd], |
|
orientation="h", |
|
marker=dict( |
|
color=colors[i], line=dict(color="rgb(248, 248, 249)", width=1) |
|
), |
|
hoverinfo="text", |
|
hovertext=f"{top_labels[i]} records: {xd[i]}", |
|
) |
|
) |
|
|
|
fig.update_layout( |
|
xaxis=dict( |
|
showgrid=False, |
|
showline=False, |
|
showticklabels=False, |
|
zeroline=False, |
|
domain=[0.15, 1], |
|
), |
|
yaxis=dict( |
|
showgrid=False, |
|
showline=False, |
|
showticklabels=False, |
|
zeroline=False, |
|
domain=[0.15, 0.5], |
|
), |
|
barmode="stack", |
|
paper_bgcolor="rgb(248, 248, 255)", |
|
plot_bgcolor="rgb(248, 248, 255)", |
|
margin=dict(l=120, r=10, t=140, b=80), |
|
showlegend=False, |
|
) |
|
|
|
annotations = [] |
|
|
|
for yd, xd in zip(y_data, x_data): |
|
|
|
annotations.append( |
|
dict( |
|
xref="paper", |
|
yref="y", |
|
x=0.14, |
|
y=yd, |
|
xanchor="right", |
|
text=str(yd), |
|
font=dict(family="Arial", size=14, color="rgb(67, 67, 67)"), |
|
showarrow=False, |
|
align="right", |
|
) |
|
) |
|
|
|
if xd[0] > 0: |
|
annotations.append( |
|
dict( |
|
xref="x", |
|
yref="y", |
|
x=xd[0] / 2, |
|
y=yd, |
|
text=str(xd[0]), |
|
font=dict(family="Arial", size=14, color="rgb(248, 248, 255)"), |
|
showarrow=False, |
|
) |
|
) |
|
space = xd[0] |
|
for i in range(1, len(xd)): |
|
if xd[i] > 0: |
|
|
|
annotations.append( |
|
dict( |
|
xref="x", |
|
yref="y", |
|
x=space + (xd[i] / 2), |
|
y=yd, |
|
text=str(xd[i]), |
|
font=dict(family="Arial", size=14, color="rgb(248, 248, 255)"), |
|
showarrow=False, |
|
) |
|
) |
|
space += xd[i] |
|
|
|
fig.update_layout(annotations=annotations, height=80) |
|
return fig |
|
|
|
|
|
def create_piechart(user_annotations):
    """Render each user's share of submissions as a pie chart.

    Args:
        user_annotations: Mapping of username -> contribution count. Slices are
            ordered from the largest to the smallest contribution.

    Returns:
        A 500 px tall ``plotly.graph_objects.Figure`` titled "User contributions",
        with percentage and user name printed inside each slice.
    """
    ranked = sorted(
        user_annotations.items(), key=lambda item: item[1], reverse=True
    )
    slice_labels = [user for user, _ in ranked]
    slice_values = [count for _, count in ranked]

    pie = go.Pie(labels=slice_labels, values=slice_values)
    fig = go.Figure(pie)

    fig.update_layout(
        title_text="User contributions",
        height=500,
        margin=dict(l=10, r=10, t=50, b=10),
        template="ggplot2",
    )
    fig.update_traces(textposition="inside", textinfo="percent+label")
    return fig
|
|
|
|
|
def get_datasets(client: rg.Argilla) -> List[rg.Dataset]:
    """Return the datasets listed by ``client.datasets.list()``."""
    dataset_api = client.datasets
    return dataset_api.list()


# Fetched once at import time: both the dropdown choices and the index-based
# lookup in update_dashboard use this list, so datasets created later are not
# shown until the app is restarted.
datasets = get_datasets(client=client)
|
|
|
from typing import Optional |
|
|
|
|
|
def update_dashboard(dataset_idx: Optional[int] = None):
    """Recompute every dashboard output for the dataset at ``dataset_idx``.

    Args:
        dataset_idx: Position of the selected dataset in the module-level
            ``datasets`` list, as emitted by the dropdown. ``None`` means
            nothing is selected.

    Returns:
        ``(progress_bar, piechart, leaderboard_df)``, where ``leaderboard_df`` has
        columns "User" and "Submitted records", sorted descending. Returns
        ``[None, None, None]`` when no valid dataset is selected.
    """
    # An empty workspace (dropdown default 0 with no datasets) or a stale index
    # would otherwise raise IndexError; negative values would silently wrap.
    if dataset_idx is None or not 0 <= dataset_idx < len(datasets):
        return [None, None, None]

    dataset = datasets[dataset_idx]
    progress = get_progress(dataset)

    progress_bar = create_progress_bar(progress)
    piechart = create_piechart(progress["users"])

    leaderboard_df = pd.DataFrame(
        list(progress["users"].items()), columns=["User", "Submitted records"]
    )
    leaderboard_df = leaderboard_df.sort_values(
        "Submitted records", ascending=False
    ).reset_index(drop=True)

    return progress_bar, piechart, leaderboard_df
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("# Argilla Progress Dashboard")

    # The dropdown shows dataset names but emits the dataset's index in
    # ``datasets``, which update_dashboard uses for the lookup.
    dataset_choices = [(dataset.name, idx) for idx, dataset in enumerate(datasets)]
    datasets_dropdown = gr.Dropdown(
        choices=dataset_choices,
        label="Select your dataset",
        # Preselect the first dataset only if one exists; 0 is not a valid value
        # for an empty choice list.
        value=0 if dataset_choices else None,
        visible=True,
    )

    def set_selected_dataset(dataset_idx) -> None:
        # NOTE(review): not attached to any event in this file; kept only
        # because it is a module-level name. Candidate for removal.
        global selected_dataset

        dataset = datasets[dataset_idx]
        selected_dataset = dataset

    with gr.Row():
        progress_bar_output = gr.Plot(label="Overall Progress")

    gr.Markdown("## Contributor Leaderboard")

    with gr.Row():
        leaderboard_output = gr.Dataframe(headers=["User", "Submitted records"])
        piechart_output = gr.Plot(label="User contributions")

    dashboard_outputs = [progress_bar_output, piechart_output, leaderboard_output]

    # Populate once per page load. A second identical load handler would only
    # repeat every Argilla API call.
    demo.load(
        update_dashboard,
        inputs=[datasets_dropdown],
        outputs=dashboard_outputs,
    )

    datasets_dropdown.change(
        update_dashboard,
        inputs=[datasets_dropdown],
        outputs=dashboard_outputs,
    )


if __name__ == "__main__":
    demo.launch()
|
|