"""Gradio demo that reacts to Argilla ``record.completed`` webhook events.

Completed records received through the webhook are pushed onto a queue. A
background thread consumes the queue and, for every event that belongs to
the demo dataset, asks a remote Gradio app for label suggestions on a few
pending records and logs those records back to the dataset.
"""

import json
import traceback
from queue import Queue
from threading import Thread
from typing import List

import argilla as rg
import gradio as gr
from gradio_client import Client

client = rg.Argilla()

# Hand-off between the async webhook handler (producer) and the worker
# thread (consumer).
completed_record_events = Queue()


def build_dataset(client: rg.Argilla) -> rg.Dataset:
    """Return the ``stanfordnlp_imdb`` dataset, building it from the Hub if needed.

    The settings are loaded from the ``stanfordnlp/imdb`` Hub dataset and a
    ``sentiment`` label question (``negative`` / ``positive``) is added to them.
    If ``client.datasets(...)`` returns a falsy value, the dataset is built
    with ``rg.Dataset.from_hub`` using the ``train[:1000]`` split.
    """
    settings = rg.Settings.from_hub("stanfordnlp/imdb")
    settings.questions.add(
        rg.LabelQuestion(name="sentiment", labels=["negative", "positive"])
    )

    dataset_name = "stanfordnlp_imdb"
    dataset = client.datasets(dataset_name) or rg.Dataset.from_hub(
        "stanfordnlp/imdb",
        name=dataset_name,
        settings=settings,
        client=client,
        split="train[:1000]",
    )

    return dataset


with gr.Blocks() as demo:
    argilla_server = client.http_client.base_url
    gr.Markdown("## Argilla Events")
    gr.Markdown(
        f"This demo shows the incoming events from the [Argilla Server]({argilla_server})."
    )
    gr.Markdown("### Record Events")
    gr.Markdown("#### Records are processed in background and suggestions are added.")

# prevent_thread_lock=True makes launch() return so that the rest of the
# script (webhook registration, worker thread) can run.
server, _, _ = demo.launch(prevent_thread_lock=True, app_kwargs={"docs_url": "/docs"})

# Set up the webhook listeners
rg.set_webhook_server(server)

# Disable every webhook already registered on the server before adding ours.
for webhook in client.webhooks:
    webhook.enabled = False
    webhook.update()


# Create a webhook for record events
@rg.webhook_listener(events="record.completed")
async def record_events(record: rg.Record, type: str, **kwargs):
    """Queue a completed record for background processing.

    NOTE(review): ``type`` shadows the builtin, but the name is kept because
    the webhook dispatcher may pass it by keyword - confirm before renaming.
    """
    print("Received event", type)

    completed_record_events.put(record)


dataset = build_dataset(client)


def add_record_suggestions_on_response_created():
    """Consume completed-record events forever and add suggestions to pending records.

    For each queued record that belongs to ``dataset``, up to five completed
    records are fetched as examples and up to five pending records are fetched
    as prediction targets. The pending records are passed through
    ``parse_pending_records`` and logged back to the dataset. Any exception
    raised while handling an event is printed and the loop moves on to the
    next event.
    """
    print("Starting thread")

    completed_records_filter = rg.Filter(("status", "==", "completed"))
    pending_records_filter = rg.Filter(("status", "==", "pending"))

    while True:
        try:
            # Blocks until the webhook handler enqueues a record.
            record: rg.Record = completed_record_events.get()

            # Ignore events that come from other datasets.
            if dataset.id != record.dataset.id:
                continue

            # Prepare predict data
            field = dataset.settings.fields["text"]
            question = dataset.settings.questions["sentiment"]

            examples = list(
                dataset.records(
                    query=rg.Query(filter=completed_records_filter),
                    limit=5,
                )
            )

            some_pending_records = list(
                dataset.records(
                    query=rg.Query(filter=pending_records_filter),
                    limit=5,
                )
            )

            if not some_pending_records:
                continue

            some_pending_records = parse_pending_records(
                some_pending_records, field, question, examples
            )
            dataset.records.log(some_pending_records)

        except Exception:
            # Best effort: one failing event must not stop the worker.
            print(traceback.format_exc())
            continue


def parse_pending_records(
    records: List[rg.Record],
    field: rg.Field,
    question,
    example_records: List[rg.Record],
) -> List[rg.Record]:
    """Attach predicted suggestions to ``records`` and return them.

    The records, the field, the question and the example records are
    serialized and sent to the ``/predict`` endpoint of the
    ``davidberenstein1957/distilabel-argilla-labeller`` Gradio app. Each entry
    of ``response["results"]`` is paired positionally with a record and added
    to it as a suggestion for ``question``.

    Any exception is printed and swallowed, so the same ``records`` list is
    always returned - possibly without new suggestions.
    """
    try:
        gradio_client = Client("davidberenstein1957/distilabel-argilla-labeller")

        payload = {
            "records": [record.to_dict() for record in records],
            "fields": [field.serialize()],
            "question": question.serialize(),
            "example_records": [record.to_dict() for record in example_records],
            "api_name": "/predict",
        }

        response = gradio_client.predict(**payload)
        # The endpoint result may arrive as a JSON string; decode it if so.
        response = json.loads(response) if isinstance(response, str) else response

        for record, suggestion in zip(records, response["results"]):
            record.suggestions.add(
                rg.Suggestion(
                    question_name=question.name,
                    value=suggestion["value"],
                    score=suggestion["score"],
                    agent=suggestion["agent"],
                )
            )
    except Exception:
        print(traceback.format_exc())

    return records


# The worker blocks indefinitely on the queue; running it as a daemon lets
# the process exit once demo.block_thread() returns instead of hanging on a
# thread that can never finish.
thread = Thread(target=add_record_suggestions_on_response_created, daemon=True)
thread.start()

demo.block_thread()