lhoestq (HF staff) committed
Commit: b4c506d
Parent: c83a2e3

implement rewrite preview

Files changed (4):
  1. .gitignore +1 -0
  2. app.py +137 -19
  3. requirements.txt +1 -0
  4. utils.py +60 -0
.gitignore ADDED
@@ -0,0 +1 @@
+__*
app.py CHANGED
@@ -1,15 +1,35 @@
+import json
+import time
 from itertools import count, islice
-from typing import Any, Iterable
+from multiprocessing.pool import ThreadPool
+from queue import Queue, Empty
+from typing import Any, Callable, Iterable, Iterator, TypeVar

 import gradio as gr
+import ijson
 import pandas as pd
 import requests
+from datasets import Features, Value, Sequence
 from gradio_huggingfacehub_search import HuggingfaceHubSearch
+from huggingface_hub import InferenceClient

+from utils import StringIteratorIO
+
+
+model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+client = InferenceClient(model_id)

 session = requests.Session()
 empty_dataframe = pd.DataFrame({"1": [], "2": [], "3": []})
-NUM_ROWS_PREVIEW = 5
+
+NUM_ROWS_PREVIEW = 3
+REWRITE_DATASET = (
+    "A Machine Learning practitioner is looking for a dataset similar to '{dataset}' but slightly different. "
+    "They want you to rewrite the dataset and apply this transformation: {prompt}. "
+    "The first rows of the dataset are below in JSON format (one JSON object per line):\n\n{rows}\n\n"
+    "Rewrite those rows from the '{dataset}' dataset using the same format (one JSON object per line). "
+    "Try to keep some of the text or meaning intact, and apply the requested transformation '{prompt}'."
+)


 with gr.Blocks() as demo:
@@ -27,15 +47,17 @@ with gr.Blocks() as demo:
     subset_dropdown = gr.Dropdown(info="Subset", show_label=False, visible=False)
     split_dropdown = gr.Dropdown(info="Split", show_label=False, visible=False)

-    input_query = gr.Textbox(label="Enter the adjustment or transformation to apply to the dataset:")
-    rewrite_button = gr.Button("ReWrite Dataset", variant="primary")
-
     gr.Markdown("### Input")
-    input_preview = gr.DataFrame(interactive=False, wrap=True)
+    input_preview = gr.DataFrame(visible=False)
+    pretty_input_preview = gr.DataFrame(interactive=False, wrap=True)

-    gr.Markdown("### Output")
+    gr.Markdown("### ReWrite")
+    input_prompt = gr.Textbox(label="Enter the adjustment or transformation to apply to the dataset:")
+    with gr.Accordion("Modify Format", open=False):
+        output_format = gr.Textbox(interactive=True, show_label=False, container=False)
+    rewrite_button = gr.Button("ReWrite Dataset", variant="primary")
     output_preview = gr.DataFrame(interactive=False, wrap=True)
-    save_button = gr.Button("Save ReWriten Dataset", interactive=False)
+    save_button = gr.Button("ReWrite Full Dataset", interactive=False)


     ############
@@ -56,6 +78,98 @@ with gr.Blocks() as demo:
                 yield row_item["row"]


+    T = TypeVar("T")
+
+
+    def batched(it: Iterable[T], n: int) -> Iterator[list[T]]:
+        it = iter(it)
+        while batch := list(islice(it, n)):
+            yield batch
+
+
+    def stream_response(messages: list[dict[str, str]], response_format=None) -> Iterator[str]:
+        for _ in range(3):
+            message = None
+            try:
+                for message in client.chat_completion(
+                    messages=messages,
+                    max_tokens=5000,
+                    stream=True,
+                    top_p=0.8,
+                    seed=42,
+                    response_format=response_format
+                ):
+                    yield message.choices[0].delta.content
+            except requests.exceptions.ConnectionError as e:
+                if message:
+                    raise
+                print(f"{e}\n\nRetrying in 1sec")
+                time.sleep(1)
+                continue
+            break
+
+
+    def stream_rewrite_dataset_row_by_row(dataset: str, rows: list[dict[str, str]], prompt: str, format: dict) -> Iterator[dict[str, str]]:
+        prompt = prompt[:1000] if prompt.strip() else ""
+        messages = [{"role": "user", "content": REWRITE_DATASET.format(
+            dataset=dataset,
+            rows=json.dumps({"data": rows}),
+            prompt=prompt,
+        )}]
+        response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "maxItems": len(rows), "minItems": len(rows), "items": format}}, "required": ["data"]}}
+        print("go")
+        yield from islice(ijson.items(StringIteratorIO(stream_response(messages, response_format=response_format)), "data.item", buf_size=4), len(rows))
+        print("done")
+
+
+    def _write_generator_to_queue(queue: Queue, func: Callable[..., Iterable], kwargs: dict) -> None:
+        for i, result in enumerate(func(**kwargs)):
+            queue.put(result)
+        return None
+
+
+    def iflatmap_unordered(
+        func: Callable[..., Iterable[T]],
+        *,
+        kwargs_iterable: Iterable[dict],
+    ) -> Iterable[T]:
+        queue = Queue()
+        with ThreadPool() as pool:
+            async_results = [pool.apply_async(_write_generator_to_queue, (queue, func, kwargs)) for kwargs in kwargs_iterable]
+            try:
+                while True:
+                    try:
+                        yield queue.get(timeout=0.05)
+                    except Empty:
+                        if all(async_result.ready() for async_result in async_results) and queue.empty():
+                            break
+            finally:  # in case there's an error to raise
+                [async_result.get(timeout=0.05) for async_result in async_results]
+
+
+    def features_to_format(features: Features) -> dict:
+        def feature_to_format(feature):
+            if isinstance(feature, Value):
+                if "int" in feature.dtype:
+                    return {"type": "integer"}
+                elif "float" in feature.dtype:
+                    return {"type": "number"}
+                else:
+                    return {"type": "string"}
+            elif isinstance(feature, list):
+                return {"type": "array", "items": feature_to_format(feature[0])}
+            elif isinstance(feature, dict):
+                return {"properties": {k: feature_to_format(v) for k, v in feature.items()}, "required": list(feature)}
+            elif isinstance(feature, Sequence):
+                if isinstance(feature.feature, dict):
+                    return {"properties": {k: {"type": "array", "items": v} for k, v in feature_to_format(feature.feature).items()}, "required": list(feature)}
+                else:
+                    return {"type": "array", "items": feature_to_format(feature.feature)}
+            else:
+                return {"type": "string"}
+        return feature_to_format(features)
+
+
     ############
     #
     # Events
@@ -78,9 +192,11 @@ with gr.Blocks() as demo:
         subset = default_subset if default_subset in subsets else subsets[0]
         splits: list[str] = info_resp["dataset_info"][subset]["splits"]
         split = default_split if default_split in splits else splits[0]
+        json_format = json.dumps(features_to_format(Features.from_dict(info_resp["dataset_info"][subset]["features"])), indent=2)
         return subset, split, {
            subset_dropdown: gr.Dropdown(value=subset, choices=subsets, visible=len(subsets) > 1),
            split_dropdown: gr.Dropdown(value=split, choices=splits, visible=len(splits) > 1),
+           output_format: gr.Textbox(json_format, lines=json_format.count("\n") + 1)
         }


@@ -88,32 +204,34 @@ with gr.Blocks() as demo:
         subset, split, output = _resolve_dataset_selection(dataset, default_subset=default_subset, default_split=default_split)
         if subset is None or split is None:
             return output
+        rows = list(islice((stream_rows(dataset, subset, split, batch_size=NUM_ROWS_PREVIEW)), NUM_ROWS_PREVIEW))
         return {
-            input_preview: pd.DataFrame(islice(({
-                k: str(v) for k, v in row.items()}
-                for row in stream_rows(dataset, subset, split, batch_size=NUM_ROWS_PREVIEW)
-            ), NUM_ROWS_PREVIEW)),
+            input_preview: pd.DataFrame(rows),
+            pretty_input_preview: pd.DataFrame([{k: str(v) for k, v in row.items()} for row in rows]),
             **output
         }


-    @dataset_search.change(inputs=[dataset_search], outputs=[input_preview, subset_dropdown, split_dropdown])
+    @dataset_search.change(inputs=[dataset_search], outputs=[input_preview, pretty_input_preview, subset_dropdown, split_dropdown, output_format])
     def show_input_from_dataset_search(dataset: str) -> dict:
         return _show_input_preview(dataset, default_subset="default", default_split="train")

-    @subset_dropdown.change(inputs=[dataset_search, subset_dropdown], outputs=[input_preview, subset_dropdown, split_dropdown])
+    @subset_dropdown.change(inputs=[dataset_search, subset_dropdown], outputs=[input_preview, pretty_input_preview, subset_dropdown, split_dropdown, output_format])
     def show_input_from_subset_dropdown(dataset: str, subset: str) -> dict:
         return _show_input_preview(dataset, default_subset=subset, default_split="train")

-    @split_dropdown.change(inputs=[dataset_search, subset_dropdown, split_dropdown], outputs=[input_preview, subset_dropdown, split_dropdown])
+    @split_dropdown.change(inputs=[dataset_search, subset_dropdown, split_dropdown], outputs=[input_preview, pretty_input_preview, subset_dropdown, split_dropdown, output_format])
     def show_input_from_split_dropdown(dataset: str, subset: str, split: str) -> dict:
         return _show_input_preview(dataset, default_subset=subset, default_split=split)


-    @rewrite_button.click(inputs=[dataset_search, subset_dropdown, split_dropdown, input_preview], outputs=[output_preview])
-    def rewrite(dataset: str, subset: str, split: str, input_preview_df: pd.DataFrame) -> dict:
-        # TODO: implement
-        return {output_preview: pd.DataFrame([{"TODO": ["implement"]}])}
+    @rewrite_button.click(inputs=[dataset_search, subset_dropdown, split_dropdown, input_preview, input_prompt, output_format], outputs=[output_preview])
+    def rewrite(dataset: str, subset: str, split: str, input_preview_df: pd.DataFrame, prompt: str, json_format: str) -> Iterator[pd.DataFrame]:
+        rows = input_preview_df.to_dict(orient="records")
+        output_rows = []
+        for row in stream_rewrite_dataset_row_by_row(dataset=dataset, rows=rows, prompt=prompt, format=json.loads(json_format)):
+            output_rows.append(row)
+            yield pd.DataFrame(output_rows)


 demo.launch()
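
For context (not part of the commit), the sketch below shows how the streaming pieces added above fit together: the token stream from client.chat_completion(..., stream=True) is wrapped in StringIteratorIO so that ijson.items(..., "data.item") can emit each rewritten row as soon as its JSON object is complete. The fake_token_stream generator is a hypothetical stand-in for the model output.

import ijson

from utils import StringIteratorIO

# Hypothetical stand-in for the chat-completion token stream: the model is asked
# to answer with a single JSON object of the form {"data": [row, row, ...]}.
def fake_token_stream():
    for chunk in ['{"data": [', '{"text": "Hello', ' world"}, ', '{"text": "Bye"}', ']}']:
        yield chunk

# ijson parses the wrapped stream incrementally, so each element of the "data"
# array is yielded as soon as it is complete; a small buf_size keeps it reactive.
for row in ijson.items(StringIteratorIO(fake_token_stream()), "data.item", buf_size=4):
    print(row)  # {'text': 'Hello world'}, then {'text': 'Bye'}

This is the same pattern stream_rewrite_dataset_row_by_row relies on, which is why the output DataFrame in the rewrite handler can grow row by row while the model is still generating.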
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 requests
 pandas
 gradio_huggingfacehub_search
+datasets
utils.py ADDED
@@ -0,0 +1,60 @@
+import io
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class StringIteratorIO(io.TextIOBase):
+    """From: https://stackoverflow.com/a/12604375"""
+
+    def __init__(self, iter):
+        self._iter = iter
+        self._left = ''
+
+    def readable(self):
+        return True
+
+    def _read1(self, n=None):
+        while not self._left:
+            try:
+                self._left = next(self._iter)
+            except StopIteration:
+                break
+        ret = self._left[:n]
+        self._left = self._left[len(ret):]
+        return ret
+
+    def read(self, n=None):
+        buf = []
+        if n is None or n < 0:
+            while True:
+                m = self._read1()
+                if not m:
+                    break
+                buf.append(m)
+        else:
+            while n > 0:
+                m = self._read1(n)
+                if not m:
+                    break
+                n -= len(m)
+                buf.append(m)
+        return ''.join(buf)
+
+    def readline(self):
+        buf = []
+        while True:
+            i = self._left.find('\n')
+            if i == -1:
+                buf.append(self._left)
+                try:
+                    self._left = next(self._iter)
+                except StopIteration:
+                    self._left = ''
+                    break
+            else:
+                buf.append(self._left[:i+1])
+                self._left = self._left[i+1:]
+                break
+        return ''.join(buf)
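
A quick usage sketch of this helper (not part of the commit): it adapts a lazy iterator of string chunks into a read()/readline() file-like object, which is what lets ijson in app.py consume the chat-completion stream.

from utils import StringIteratorIO

chunks = iter(["first line\nsec", "ond line\n", "tail"])
f = StringIteratorIO(chunks)

print(repr(f.readline()))  # 'first line\n'  (pulls chunks until a newline appears)
print(repr(f.read(6)))     # 'second'        (continues from the leftover buffer)
print(repr(f.read()))      # ' line\ntail'   (drains the remaining chunks)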