Francisco Aranda committed
Commit 48548de
1 Parent(s): 36e4ada

suggest question values on record completed

Files changed (1)
  1. app.py +116 -39
app.py CHANGED
@@ -1,74 +1,151 @@
 
 
  from queue import Queue

  import argilla as rg
  import gradio as gr

  client = rg.Argilla()

- incoming_events = Queue()

- def check_incoming_events():
-     """
-     This function is called every 5 seconds to check if there are any incoming
-     events and send data to update the JSON component.
-     """
-     events = []
-     while not incoming_events.empty():
-         events.append(incoming_events.get())

-     return {"events": events}


  with gr.Blocks() as demo:
      argilla_server = client.http_client.base_url
      gr.Markdown("## Argilla Events")
      gr.Markdown(f"This demo shows the incoming events from the [Argilla Server]({argilla_server}).")
-     json_component = gr.JSON(label="Incoming argilla events:")
-     gr.Timer(5, active=True).tick(check_incoming_events, outputs=json_component)
-

  server, _, _ = demo.launch(prevent_thread_lock=True, app_kwargs={"docs_url": "/docs"})

  # Set up the webhook listeners
  rg.set_webhook_server(server)

- # Delete all existing webhooks
  for webhook in client.webhooks:
-     webhook.delete()

  # Create a webhook for record events

- @rg.webhook_listener(
-     events=["record.created", "record.updated", "record.completed"],
-     raw_event=True  # Using raw events until PR https://github.com/argilla-io/argilla/pull/5500 is merged
- )
- async def record_events(event: dict):
-     print("Received event", event)

-     incoming_events.put(event)

- # Create a webhook for dataset events
- @rg.webhook_listener(events=["dataset.created", "dataset.updated", "dataset.published"])
- async def dataset_events(type: str, dataset: rg.Dataset | None = None, **kwargs):
-     print(f"Received event {type} for dataset {dataset.id}")

-     incoming_events.put((type, dataset))


- # Create a webhook for response events
- @rg.webhook_listener(
-     events=["response.created", "response.updated"],
-     raw_event=True  # Using raw events until PR https://github.com/argilla-io/argilla/pull/5500 is merged
- )
- async def response_events(event: dict):
-     print("Received event", event)

-     incoming_events.put(event)

- @rg.webhook_listener(events=["record.deleted", "dataset.deleted", "response.deleted"])
- async def deleted_events(type: str, data: dict, **kwargs):
-     print(f"Received event {type} for resource {data}")

-     incoming_events.put((type, data))

  demo.block_thread()
 
+ import json
+ import traceback
  from queue import Queue
+ from threading import Thread
+ from typing import List

  import argilla as rg
  import gradio as gr
+ from gradio_client import Client

  client = rg.Argilla()

+ completed_record_events = Queue()


+ def build_dataset(client: rg.Argilla) -> rg.Dataset:
+     settings = rg.Settings.from_hub("stanfordnlp/imdb")
+
+     settings.questions.add(rg.LabelQuestion(name="sentiment", labels=["negative", "positive"]))
+
+     dataset_name = "stanfordnlp_imdb"
+     dataset = client.datasets(dataset_name) or rg.Dataset.from_hub(
+         "stanfordnlp/imdb",
+         name=dataset_name,
+         settings=settings,
+         client=client,
+         split="train[:1000]"
+     )
+
+     return dataset


  with gr.Blocks() as demo:
      argilla_server = client.http_client.base_url
      gr.Markdown("## Argilla Events")
      gr.Markdown(f"This demo shows the incoming events from the [Argilla Server]({argilla_server}).")
+     gr.Markdown("### Record Events")
+     gr.Markdown("#### Records are processed in background and suggestions are added.")

  server, _, _ = demo.launch(prevent_thread_lock=True, app_kwargs={"docs_url": "/docs"})

  # Set up the webhook listeners
  rg.set_webhook_server(server)

  for webhook in client.webhooks:
+     webhook.enabled = False
+     webhook.update()
+

  # Create a webhook for record events
+ @rg.webhook_listener(events="record.completed")
+ async def record_events(record: rg.Record, type: str, **kwargs):
+     print("Received event", type)
+
+     completed_record_events.put(record)
+
+
+ dataset = build_dataset(client)
+
+
+ def add_record_suggestions_on_response_created():
+     print("Starting thread")
+
+     completed_records_filter = rg.Filter(("status", "==", "completed"))
+     pending_records_filter = rg.Filter(("status", "==", "pending"))
+
+     while True:
+         try:
+             record: rg.Record = completed_record_events.get()
+
+             if dataset.id != record.dataset.id:
+                 continue
+
+             # Prepare predict data
+
+             field = dataset.settings.fields["text"]
+             question = dataset.settings.questions["sentiment"]
+
+             examples = list(
+                 dataset.records(
+                     query=rg.Query(filter=completed_records_filter),
+                     limit=5,
+                 )
+             )
+
+             some_pending_records = list(
+                 dataset.records(
+                     query=rg.Query(filter=pending_records_filter),
+                     limit=5,
+                 )
+             )
+
+             if not some_pending_records:
+                 continue
+
+             some_pending_records = parse_pending_records(some_pending_records, field, question, examples)
+             dataset.records.log(some_pending_records)
+
+         except Exception:
+             print(traceback.format_exc())
+             continue


+ def parse_pending_records(
+     records: List[rg.Record],
+     field: rg.Field,
+     question,
+     example_records: List[rg.Record]
+ ) -> List[rg.Record]:
+     gradio_client = Client("davidberenstein1957/distilabel-argilla-labeller")

+     payload = {
+         "records": [record.to_dict() for record in records],
+         "fields": [field.serialize()],
+         "question": question.serialize(),
+         "example_records": [record.to_dict() for record in example_records],
+     }

+     print("Sending payload")
+     # response = gradio_client.predict(
+     #     records=json.dumps(payload["records"]),
+     #     example_records=json.dumps(payload["example_records"]),
+     #     fields=json.dumps(payload["fields"]),
+     #     question=json.dumps(payload["question"]),
+     #     api_name="/predict"
+     # )
+     # print(response)
+     # response = json.loads(response)
+     # print("Response ", response)

+     response = {
+         "results": [{"value": "positive", "score": None, "agent": "mock"} for _ in records]
+     }

+     for record, suggestion in zip(records, response["results"]):
+         record.suggestions.add(
+             rg.Suggestion(
+                 question_name=question.name,
+                 value=suggestion["value"],
+                 score=suggestion["score"],
+                 agent=suggestion["agent"],
+             )
+         )

+     return records


+ thread = Thread(target=add_record_suggestions_on_response_created)

+ thread.start()
  demo.block_thread()
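
Note: the hard-coded `response` dict in `parse_pending_records` is a mock that stands in for the Space call still commented out above. A minimal sketch of what re-enabling it could look like, assuming the `/predict` endpoint of the `davidberenstein1957/distilabel-argilla-labeller` Space keeps the signature shown in the commented block (the helper name `predict_suggestions` is hypothetical and not part of the commit):

import json

from gradio_client import Client


def predict_suggestions(payload: dict) -> dict:
    # Hypothetical helper: sends the same JSON-encoded payload that
    # parse_pending_records already prepares to the labeller Space.
    gradio_client = Client("davidberenstein1957/distilabel-argilla-labeller")
    raw = gradio_client.predict(
        records=json.dumps(payload["records"]),
        example_records=json.dumps(payload["example_records"]),
        fields=json.dumps(payload["fields"]),
        question=json.dumps(payload["question"]),
        api_name="/predict",
    )
    # Per the commented-out code, the Space is expected to return a JSON string
    # with a "results" list of {"value", "score", "agent"} entries, one per record.
    return json.loads(raw)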