import json
import traceback
from queue import Queue
from threading import Thread
from typing import List

import argilla as rg
import gradio as gr
from gradio_client import Client

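# Connect to the Argilla server; rg.Argilla() picks up ARGILLA_API_URL and ARGILLA_API_KEY from the environment by default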
client = rg.Argilla()

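# Completed records pushed by the webhook handler and drained by the background worker thread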
completed_record_events = Queue()


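# Fetch the dataset if it already exists on the server; otherwise import stanfordnlp/imdb
# from the Hub and extend its settings with a sentiment LabelQuestion to annotate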
def build_dataset(client: rg.Argilla) -> rg.Dataset:
    settings = rg.Settings.from_hub("stanfordnlp/imdb")
    settings.questions.add(
        rg.LabelQuestion(name="sentiment", labels=["negative", "positive"])
    )

    dataset_name = "stanfordnlp_imdb"
    dataset = client.datasets(dataset_name) or rg.Dataset.from_hub(
        "stanfordnlp/imdb",
        name=dataset_name,
        settings=settings,
        client=client,
        split="train[:1000]",
    )
    return dataset


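# Minimal Gradio UI that links back to the Argilla server emitting the events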
with gr.Blocks() as demo:
    argilla_server = client.http_client.base_url
    gr.Markdown("## Argilla Events")
    gr.Markdown(
        f"This demo shows the incoming events from the [Argilla Server]({argilla_server})."
    )
    gr.Markdown("### Record Events")
    gr.Markdown("#### Records are processed in the background and suggestions are added.")

server, _, _ = demo.launch(prevent_thread_lock=True, app_kwargs={"docs_url": "/docs"})

# Set up the webhook listeners
rg.set_webhook_server(server)

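# Disable webhooks left over from previous runs so only this app's listeners fire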
for webhook in client.webhooks:
    webhook.enabled = False
    webhook.update()


# Create a webhook for record events. The decorator and event name are assumptions:
# "record.completed" matches the completed-record queue and the "completed" filter used below.
@rg.webhook_listener(events=["record.completed"])
async def record_events(record: rg.Record, type: str, **kwargs):
    print("Received event", type)
    completed_record_events.put(record)


dataset = build_dataset(client)


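# Background worker: each completed record triggers a pass that asks the labeller Space
# to suggest labels for up to 5 still-pending records, using up to 5 completed records as few-shot examples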
def add_record_suggestions_on_response_created():
    print("Starting thread")
    completed_records_filter = rg.Filter(("status", "==", "completed"))
    pending_records_filter = rg.Filter(("status", "==", "pending"))
    while True:
        try:
            record: rg.Record = completed_record_events.get()
            if dataset.id != record.dataset.id:
                continue
            # Prepare predict data
            field = dataset.settings.fields["text"]
            question = dataset.settings.questions["sentiment"]
            examples = list(
                dataset.records(
                    query=rg.Query(filter=completed_records_filter),
                    limit=5,
                )
            )
            some_pending_records = list(
                dataset.records(
                    query=rg.Query(filter=pending_records_filter),
                    limit=5,
                )
            )
            if not some_pending_records:
                continue
            some_pending_records = parse_pending_records(
                some_pending_records, field, question, examples
            )
            dataset.records.log(some_pending_records)
        except Exception:
            print(traceback.format_exc())
            continue


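# Call the hosted distilabel-argilla-labeller Space and attach its predictions to the records as suggestions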
def parse_pending_records(
    records: List[rg.Record],
    field: rg.Field,
    question,
    example_records: List[rg.Record],
) -> List[rg.Record]:
    try:
        gradio_client = Client("davidberenstein1957/distilabel-argilla-labeller")
        payload = {
            "records": [record.to_dict() for record in records],
            "fields": [field.serialize()],
            "question": question.serialize(),
            "example_records": [record.to_dict() for record in example_records],
            "api_name": "/predict",
        }
        response = gradio_client.predict(**payload)
        response = json.loads(response) if isinstance(response, str) else response
        for record, suggestion in zip(records, response["results"]):
            record.suggestions.add(
                rg.Suggestion(
                    question_name=question.name,
                    value=suggestion["value"],
                    score=suggestion["score"],
                    agent=suggestion["agent"],
                )
            )
    except Exception:
        print(traceback.format_exc())
    return records


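# Start the suggestion worker in the background and keep the Gradio app running in the main thread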
thread = Thread(target=add_record_suggestions_on_response_created)
thread.start()

demo.block_thread()