import time
from itertools import islice
from functools import partial
from typing import Iterable, Iterator, TypeVar

import gradio as gr
import requests.exceptions
from huggingface_hub import InferenceClient

model_id = "microsoft/Phi-3-mini-4k-instruct"
client = InferenceClient(model_id)

MAX_TOTAL_NB_ITEMS = 100  # almost infinite, don't judge me (actually it's because gradio needs a fixed number of components)
MAX_NB_ITEMS_PER_GENERATION_CALL = 10

URL = "https://huggingface.co/spaces/infinite-dataset-hub/infinite-dataset-hub"
GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY = (
    "A Machine Learning Practitioner is looking for a dataset that matches '{search_query}'. "
    f"Generate a list of {MAX_NB_ITEMS_PER_GENERATION_CALL} names of quality datasets that don't exist but sound plausible and would "
    "be helpful. Feel free to reuse words from the query '{search_query}' to name the datasets. "
    "Every dataset should be about '{search_query}' and have descriptive tags/keywords including the ML task name associated with the dataset (classification, regression, anomaly detection, etc.). Use the following format:\n1. DatasetName1 (tag1, tag2, tag3)\n2. DatasetName2 (tag1, tag2, tag3)"
)
GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS = (
    "An ML practitioner is looking for a dataset CSV after the query '{search_query}'. "
    "Generate the first 5 rows of a plausible and quality CSV for the dataset '{dataset_name}'. "
    "You can get inspiration from related keywords '{tags}' but most importantly the dataset should correspond to the query '{search_query}'. "
    "Focus on quality text content and use a 'label' or 'labels' column if it makes sense (invent labels, avoid reusing the keywords, be accurate while labelling texts). "
    "Reply using a short description of the dataset with title **Dataset Description:** followed by the CSV content in a code block and with title **CSV Content Preview:**."
)
landing_page_query = "various datasets on many different subjects and topics, from classification to language modeling, from science to sport to finance to news"

landing_page_datasets_generated_text = """
1. NewsEventsPredict (classification, media, trend)
2. FinancialForecast (economy, stocks, regression)
3. HealthMonitor (science, real-time, anomaly detection)
4. SportsAnalysis (classification, performance, player tracking)
5. SciLiteracyTools (language modeling, science literacy, text classification)
6. RetailSalesAnalyzer (consumer behavior, sales trend, segmentation)
7. SocialSentimentEcho (social media, emotion analysis, clustering)
8. NewsEventTracker (classification, public awareness, topical clustering)
9. HealthVitalSigns (anomaly detection, biometrics, prediction)
10. GameStockPredict (classification, finance, sports contingency)
"""
default_output = landing_page_datasets_generated_text.strip().split("\n")
assert len(default_output) == MAX_NB_ITEMS_PER_GENERATION_CALL
css = """ | |
a { | |
color: var(--body-text-color); | |
} | |
.datasetButton { | |
justify-content: start; | |
justify-content: left; | |
} | |
.tags { | |
font-size: var(--button-small-text-size); | |
color: var(--body-text-color-subdued); | |
} | |
.topButton { | |
justify-content: start; | |
justify-content: left; | |
text-align: left; | |
background: transparent; | |
box-shadow: none; | |
padding-bottom: 0; | |
} | |
.topButton::before { | |
content: url("data:image/svg+xml,%3Csvg style='color: rgb(209 213 219)' xmlns='http://www.w3.org/2000/svg' xmlns:xlink='http://www.w3.org/1999/xlink' aria-hidden='true' focusable='false' role='img' width='1em' height='1em' preserveAspectRatio='xMidYMid meet' viewBox='0 0 25 25'%3E%3Cellipse cx='12.5' cy='5' fill='currentColor' fill-opacity='0.25' rx='7.5' ry='2'%3E%3C/ellipse%3E%3Cpath d='M12.5 15C16.6421 15 20 14.1046 20 13V20C20 21.1046 16.6421 22 12.5 22C8.35786 22 5 21.1046 5 20V13C5 14.1046 8.35786 15 12.5 15Z' fill='currentColor' opacity='0.5'%3E%3C/path%3E%3Cpath d='M12.5 7C16.6421 7 20 6.10457 20 5V11.5C20 12.6046 16.6421 13.5 12.5 13.5C8.35786 13.5 5 12.6046 5 11.5V5C5 6.10457 8.35786 7 12.5 7Z' fill='currentColor' opacity='0.5'%3E%3C/path%3E%3Cpath d='M5.23628 12C5.08204 12.1598 5 12.8273 5 13C5 14.1046 8.35786 15 12.5 15C16.6421 15 20 14.1046 20 13C20 12.8273 19.918 12.1598 19.7637 12C18.9311 12.8626 15.9947 13.5 12.5 13.5C9.0053 13.5 6.06886 12.8626 5.23628 12Z' fill='currentColor'%3E%3C/path%3E%3C/svg%3E"); | |
margin-right: .25rem; | |
margin-left: -.125rem; | |
margin-top: .25rem; | |
} | |
.bottomButton { | |
justify-content: start; | |
justify-content: left; | |
text-align: left; | |
background: transparent; | |
box-shadow: none; | |
font-size: var(--button-small-text-size); | |
color: var(--body-text-color-subdued); | |
padding-top: 0; | |
align-items: baseline; | |
} | |
.bottomButton::before { | |
content: 'tags:'; | |
margin-right: .25rem; | |
} | |
.buttonsGroup { | |
background: transparent; | |
} | |
.buttonsGroup:hover { | |
background: var(--input-background-fill); | |
} | |
.buttonsGroup div { | |
background: transparent; | |
} | |
.invisibleButtonGroup {
    display: none;
}
@keyframes placeHolderShimmer {
    0% {
        background-position: -468px 0
    }
    100% {
        background-position: 468px 0
    }
}
.linear-background {
    animation-duration: 1s;
    animation-fill-mode: forwards;
    animation-iteration-count: infinite;
    animation-name: placeHolderShimmer;
    animation-timing-function: linear;
    background-image: linear-gradient(to right, var(--body-text-color-subdued) 8%, #dddddd11 18%, var(--body-text-color-subdued) 33%);
    background-size: 1000px 104px;
    color: transparent;
    background-clip: text;
}
"""
with gr.Blocks(css=css) as demo:
    generated_texts_state = gr.State((landing_page_datasets_generated_text,))
    with gr.Row():
        with gr.Column(scale=4, min_width=0):
            pass
        with gr.Column(scale=10):
            gr.Markdown(
                "# 🤗 Infinite Dataset Hub ♾️\n\n"
                "An endless catalog of datasets, created just for you.\n\n"
            )
        with gr.Column(scale=4, min_width=0):
            pass
    with gr.Column() as search_page:
        with gr.Row():
            with gr.Column(scale=4, min_width=0):
                pass
            with gr.Column(scale=10):
                with gr.Row():
                    search_bar = gr.Textbox(max_lines=1, placeholder="Search datasets, get infinite results", show_label=False, container=False, scale=9)
                    search_button = gr.Button("🔍", variant="primary", scale=1)
            with gr.Column(scale=4, min_width=0):
                pass
        with gr.Row():
            with gr.Column(scale=4, min_width=0):
                pass
            with gr.Column(scale=10):
                button_groups: list[gr.Group] = []
                buttons: list[gr.Button] = []
                for i in range(MAX_TOTAL_NB_ITEMS):
                    if i < len(default_output):
                        line = default_output[i]
                        dataset_name, tags = line.split(".", 1)[1].strip(" )").split(" (", 1)
                        group_classes = "buttonsGroup"
                        dataset_name_classes = "topButton"
                        tags_classes = "bottomButton"
                    else:
                        dataset_name, tags = "⬜⬜⬜⬜⬜⬜", "░░░░, ░░░░, ░░░░"
                        group_classes = "buttonsGroup invisibleButtonGroup"
                        dataset_name_classes = "topButton linear-background"
                        tags_classes = "bottomButton linear-background"
                    with gr.Group(elem_classes=group_classes) as button_group:
                        button_groups.append(button_group)
                        buttons.append(gr.Button(dataset_name, elem_classes=dataset_name_classes))
                        buttons.append(gr.Button(tags, elem_classes=tags_classes))
                load_more_datasets = gr.Button("Load more datasets")  # TODO: disable when reaching the end of the page
                gr.Markdown(f"_powered by [{model_id}](https://huggingface.co/{model_id})_")
            with gr.Column(scale=4, min_width=0):
                pass
    with gr.Column(visible=False) as dataset_page:
        with gr.Row():
            with gr.Column(scale=4, min_width=0):
                pass
            with gr.Column(scale=10):
                dataset_title = gr.Markdown()
                dataset_content = gr.Markdown()
                with gr.Row():
                    with gr.Column(scale=4, min_width=0):
                        pass
                    with gr.Column():
                        generate_full_dataset_button = gr.Button("Generate Full Dataset", variant="primary")  # TODO: implement
                        dataset_share_button = gr.Button("Share Dataset URL")
                        dataset_share_textbox = gr.Textbox(visible=False, show_copy_button=True, label="Copy this URL:", interactive=False, show_label=True)
                        back_button = gr.Button("< Back", size="sm")
                    with gr.Column(scale=4, min_width=0):
                        pass
            with gr.Column(scale=4, min_width=0):
                pass

    app_state = gr.State({})
T = TypeVar("T") | |
def batched(it: Iterable[T], n: int) -> Iterator[list[T]]: | |
it = iter(it) | |
while batch := list(islice(it, n)): | |
yield batch | |
    def stream_response(msg: str, generated_texts: tuple[str, ...] = (), max_tokens=500) -> Iterator[str]:
        messages = [
            {"role": "user", "content": msg}
        ] + [
            item
            for generated_text in generated_texts
            for item in [
                {"role": "assistant", "content": generated_text},
                {"role": "user", "content": "Can you generate more?"},
            ]
        ]
        for _ in range(3):
            try:
                for message in client.chat_completion(
                    messages=messages,
                    max_tokens=max_tokens,
                    stream=True,
                    top_p=0.8,
                    seed=42,
                ):
                    # delta.content can be None on some stream chunks; guard so
                    # callers can safely concatenate tokens
                    yield message.choices[0].delta.content or ""
            except requests.exceptions.ConnectionError as e:
                print(f"{e}\n\nRetrying in 1sec")
                time.sleep(1)
                continue
            break
    def gen_datasets_line_by_line(search_query: str, generated_texts: tuple[str, ...] = ()) -> Iterator[str]:
        search_query = search_query or ""
        search_query = search_query[:1000] if search_query.strip() else landing_page_query
        generated_text = ""
        current_line = ""
        for token in stream_response(
            GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY.format(search_query=search_query),
            generated_texts=generated_texts,
        ):
            current_line += token
            if current_line.endswith("\n"):
                yield current_line
                generated_text += current_line
                current_line = ""
        yield current_line
        generated_text += current_line
        print("-----\n\n" + generated_text)
    def gen_dataset_content(search_query: str, dataset_name: str, tags: str) -> Iterator[str]:
        search_query = search_query or ""
        search_query = search_query[:1000] if search_query.strip() else landing_page_query
        generated_text = ""
        for token in stream_response(GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS.format(
            search_query=search_query,
            dataset_name=dataset_name,
            tags=tags,
        ), max_tokens=1500):
            generated_text += token
            yield generated_text
        print("-----\n\n" + generated_text)
    def _search_datasets(search_query):
        yield {generated_texts_state: [], app_state: {"search_query": search_query}}
        yield {
            button_group: gr.Group(elem_classes="buttonsGroup invisibleButtonGroup")
            for button_group in button_groups[MAX_NB_ITEMS_PER_GENERATION_CALL:]
        }
        yield {
            k: v
            for dataset_name_button, tags_button in batched(buttons, 2)
            for k, v in {
                dataset_name_button: gr.Button("⬜⬜⬜⬜⬜⬜", elem_classes="topButton linear-background"),
                tags_button: gr.Button("░░░░, ░░░░, ░░░░", elem_classes="bottomButton linear-background")
            }.items()
        }
        current_item_idx = 0
        generated_text = ""
        for line in gen_datasets_line_by_line(search_query):
            if "I'm sorry" in line:
                raise gr.Error("Error: inappropriate content")
            if current_item_idx >= MAX_NB_ITEMS_PER_GENERATION_CALL:
                return
            if line.strip() and line.strip().split(".", 1)[0].isnumeric():
                try:
                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" (", 1)
                except ValueError:
                    # Fallback when the model didn't wrap the tags in parentheses
                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" ", 1)[0], ""
                dataset_name, tags = dataset_name.strip("()[]* "), tags.strip("()[]* ")
                generated_text += line
                yield {
                    buttons[2 * current_item_idx]: gr.Button(dataset_name, elem_classes="topButton"),
                    buttons[2 * current_item_idx + 1]: gr.Button(tags, elem_classes="bottomButton"),
                    generated_texts_state: (generated_text,),
                }
                current_item_idx += 1
    def search_dataset_from_search_button(search_query):
        yield from _search_datasets(search_query)

    def search_dataset_from_search_bar(search_query):
        yield from _search_datasets(search_query)
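    # NOTE: assumed wiring (reconstructed, not in the original listing): without
    # these hooks the two search handlers above are never triggered.
    # `search_outputs` is a helper name introduced here; the components listed
    # follow the keys yielded by _search_datasets.
    search_outputs = [app_state, generated_texts_state, *button_groups, *buttons]
    search_button.click(search_dataset_from_search_button, inputs=search_bar, outputs=search_outputs)
    search_bar.submit(search_dataset_from_search_bar, inputs=search_bar, outputs=search_outputs)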
    def search_more_datasets(search_query, generated_texts):
        current_item_idx = initial_item_idx = len(generated_texts) * MAX_NB_ITEMS_PER_GENERATION_CALL
        yield {
            button_group: gr.Group(elem_classes="buttonsGroup")
            for button_group in button_groups[len(generated_texts) * MAX_NB_ITEMS_PER_GENERATION_CALL:(len(generated_texts) + 1) * MAX_NB_ITEMS_PER_GENERATION_CALL]
        }
        generated_text = ""
        for line in gen_datasets_line_by_line(search_query, generated_texts=generated_texts):
            if "I'm sorry" in line or "against Microsoft's use case policy" in line:
                raise gr.Error("Error: inappropriate content")
            if current_item_idx - initial_item_idx >= MAX_NB_ITEMS_PER_GENERATION_CALL:
                return
            if line.strip() and line.strip().split(".", 1)[0].isnumeric():
                try:
                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" (", 1)
                except ValueError:
                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" ", 1)[0], ""
                dataset_name, tags = dataset_name.strip("()[]* "), tags.strip("()[]* ")
                generated_text += line
                yield {
                    buttons[2 * current_item_idx]: gr.Button(dataset_name, elem_classes="topButton"),
                    buttons[2 * current_item_idx + 1]: gr.Button(tags, elem_classes="bottomButton"),
                    generated_texts_state: (*generated_texts, generated_text),
                }
                current_item_idx += 1
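    # Assumed wiring for the "Load more datasets" button (reconstructed); the
    # outputs mirror the keys yielded by search_more_datasets.
    load_more_datasets.click(
        search_more_datasets,
        inputs=[search_bar, generated_texts_state],
        outputs=[generated_texts_state, *button_groups, *buttons],
    )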
    def _show_dataset(search_query, dataset_name, tags):
        yield {
            search_page: gr.Column(visible=False),
            dataset_page: gr.Column(visible=True),
            dataset_title: f"# {dataset_name}\n\n tags: {tags}\n\n _Note: This is an AI-generated dataset so its content may be inaccurate or false_",
            dataset_share_textbox: gr.Textbox(visible=False),
            app_state: {
                "search_query": search_query,
                "dataset_name": dataset_name,
                "tags": tags,
            },
        }
        for generated_text in gen_dataset_content(search_query=search_query, dataset_name=dataset_name, tags=tags):
            yield {dataset_content: generated_text}

    show_dataset_inputs = [search_bar, *buttons]
    show_dataset_outputs = [app_state, search_page, dataset_page, dataset_title, dataset_content, dataset_share_textbox]
scroll_to_top_js = """ | |
function (...args) { | |
console.log(args); | |
if ('parentIFrame' in window) { | |
window.parentIFrame.scrollTo({top: 0, behavior:'smooth'}); | |
} else { | |
window.scrollTo({ top: 0 }); | |
} | |
return args; | |
} | |
""".replace("len(show_dataset_inputs)", str(len(show_dataset_inputs))) | |
    def show_dataset_from_button(search_query, *buttons_values, i):
        dataset_name, tags = buttons_values[2 * i : 2 * i + 2]
        yield from _show_dataset(search_query, dataset_name, tags)

    for i, (dataset_name_button, tags_button) in enumerate(batched(buttons, 2)):
        dataset_name_button.click(partial(show_dataset_from_button, i=i), inputs=show_dataset_inputs, outputs=show_dataset_outputs, js=scroll_to_top_js)
        tags_button.click(partial(show_dataset_from_button, i=i), inputs=show_dataset_inputs, outputs=show_dataset_outputs, js=scroll_to_top_js)
    def show_search_page():
        return gr.Column(visible=True), gr.Column(visible=False)
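    # Assumed wiring (reconstructed): the back button swaps the two pages and
    # scrolls back to the top.
    back_button.click(show_search_page, inputs=[], outputs=[search_page, dataset_page], js=scroll_to_top_js)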
    def generate_full_dataset():
        raise gr.Error("Not implemented yet, sorry! Request your dataset in the Discussion tab (provide the dataset URL)")
    def show_dataset_url(state):
        return gr.Textbox(
            f"{URL}?q={state['search_query'].replace(' ', '+')}&dataset={state['dataset_name']}&tags={state['tags']}",
            visible=True,
        )
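    # Assumed wiring (reconstructed): reveal the shareable URL built from the
    # current app state.
    dataset_share_button.click(show_dataset_url, inputs=[app_state], outputs=[dataset_share_textbox])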
    def load_app(request: gr.Request):
        query_params = dict(request.query_params)
        if "dataset" in query_params:
            yield from _show_dataset(
                search_query=query_params.get("q", query_params["dataset"]),
                dataset_name=query_params["dataset"],
                tags=query_params.get("tags", ""),
            )
        elif "q" in query_params:
            yield {search_bar: query_params["q"]}
            yield from _search_datasets(query_params["q"])
        else:
            yield {search_page: gr.Column(visible=True)}
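    # Assumed on-load hook (reconstructed): load_app is never invoked otherwise.
    # The outputs list is inferred from everything load_app can yield.
    demo.load(
        load_app,
        outputs=[
            app_state, search_bar, search_page, dataset_page, dataset_title,
            dataset_content, dataset_share_textbox, generated_texts_state,
            *button_groups, *buttons,
        ],
    )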
demo.launch()