|
import ast
import glob
from functools import partial
from itertools import islice
from threading import Thread
from typing import Optional, Type

import gradio as gr
import nltk
import pandas as pd
from datatrove.data import Document
from datatrove.executor.local import LocalPipelineExecutor
from datatrove.pipeline.extractors import Trafilatura
from datatrove.pipeline.filters import (
    C4QualityFilter,
    FineWebQualityFilter,
    GopherQualityFilter,
    GopherRepetitionFilter,
    LanguageFilter,
    URLFilter,
)
from datatrove.pipeline.filters.base_filter import BaseFilter
from datatrove.pipeline.formatters import PIIFormatter
from datatrove.pipeline.readers import JsonlReader, WarcReader
from datatrove.utils.typeshelper import Languages
|
|
|
|
|
# Fetch the NLTK "punkt_tab" resource at import time.
# NOTE(review): presumably required by the tokenizers used in the datatrove filters - confirm.
nltk.download('punkt_tab')

# Common Crawl dump whose pre-computed outputs are read from disk.
DUMP_TO_PROCESS = "CC-MAIN-2023-50"

# Pre-computed pipeline outputs, loaded eagerly as lists of record dicts.
# NOTE(review): neither list is referenced anywhere else in the visible file, yet both are
# read from disk on every import - confirm they are still needed.
default_output_docs_2k = pd.read_json(
    f"output_all-2k/base_processing/output/{DUMP_TO_PROCESS}/00000.jsonl.gz",
    compression="gzip",
    lines=True,
).to_dict(orient="records")
default_output_docs_200 = pd.read_json(
    f"output_all-200/base_processing/output/{DUMP_TO_PROCESS}/00000.jsonl.gz",
    compression="gzip",
    lines=True,
).to_dict(orient="records")
|
|
|
# JavaScript handed to `gr.Blocks(js=...)`.
# It attaches a MutationObserver to every ".thumbnail-item" inside the element with id
# "pipeline-gallery"; when thumbnail i gains the "selected" class, it clicks the i-th
# element that has the "block-button" class.
make_gallery_image_buttons_js = """
function load() {
    class ClassWatcher {

        constructor(targetNode, classToWatch, classAddedCallback, arg) {
            this.targetNode = targetNode
            this.classToWatch = classToWatch
            this.classAddedCallback = classAddedCallback
            this.arg = arg
            this.observer = null
            this.lastClassState = targetNode.classList.contains(this.classToWatch)

            this.init()
        }

        init() {
            this.observer = new MutationObserver(this.mutationCallback)
            this.observe()
        }

        observe() {
            this.observer.observe(this.targetNode, { attributes: true })
        }

        disconnect() {
            this.observer.disconnect()
        }

        mutationCallback = mutationsList => {
            for (let mutation of mutationsList) {
                if (mutation.type === 'attributes' && mutation.attributeName === 'class') {
                    let currentClassState = mutation.target.classList.contains(this.classToWatch)
                    if(this.lastClassState !== currentClassState) {
                        this.lastClassState = currentClassState
                        if(currentClassState) {
                            this.classAddedCallback(this.arg)
                        }
                    }
                }
            }
        }
    }
    let buttons = document.getElementsByClassName("block-button");
    function clickButton(i) {
        buttons[i].click();
    }
    Array.from(document.getElementById("pipeline-gallery").getElementsByClassName("thumbnail-item")).map(
        (b, i) => new ClassWatcher(b, 'selected', clickButton, i)
    )
}
"""
|
# Stylesheet handed to `gr.Blocks(css=...)`.
# The first two rules colour table rows that contain diffInsertion / diffDeletion cells;
# the remaining rules adjust table borders and the gallery grid / thumbnail sizing.
css = """
tr:has(> td div span span div.diffInsertion) {
    background: darkgreen;
}
tr:has(> td div span span div.diffDeletion) {
    background: darkred;
}
tr td {
    border-top: 1px solid black;
}
.grid-container {
    gap: 0;
    grid-template-rows: auto;
    grid-auto-rows: auto;
}
.thumbnail-item {
    aspect-ratio: auto;
    height: min-content;
}
.grid-wrap {
    min-height: 0;
}
"""
|
|
|
|
|
# Thumbnails shown in the step gallery. The sorted order matters: thumbnail i is paired
# by index with the i-th hidden button and the i-th step panel.
blocks = sorted(glob.glob("images/*.png"))
|
|
|
|
|
def prepare_as_list_or_none(text: Optional[str]) -> Optional[list[str]]:
    """Split a comma-separated string into a list of stripped, non-empty items.

    Args:
        text: Raw textbox content; may be empty or ``None``.

    Returns:
        The list of items, or ``None`` when ``text`` is falsy or holds no non-blank item.
    """
    if not text:
        return None
    items = [piece.strip() for piece in text.split(",")]
    items = [item for item in items if item]
    # An empty result is reported as None rather than [].
    return items or None
|
|
|
def non_empty_list_or_none(input_list: Optional[list[str]]) -> Optional[list[str]]:
    """Return ``input_list`` unchanged when it has items, otherwise ``None``.

    Args:
        input_list: A list of values; may be empty or ``None``.

    Returns:
        The same list object when non-empty, else ``None``.
    """
    if not input_list:
        return None
    return input_list
|
|
|
def build_code_snippet(steps, params=None):
    """Return the Markdown text for the pipeline's Python code.

    Placeholder implementation: ``steps`` and ``params`` are accepted but not used yet,
    and the result is always a fenced ``python`` block containing ``TODO``.

    Args:
        steps: Pipeline step classes (currently ignored).
        params: Optional step parameters (currently ignored).

    Returns:
        A Markdown string with a fenced Python code block.
    """
    return "```python\nTODO\n```"
|
|
|
|
|
# Gradio UI: a gallery of pipeline steps, one parameter panel per step, result tabs, and
# the click handler that runs the datatrove pipeline and streams results.
# NOTE(review): the nesting of layout containers (Group / Accordion / Row / Column) below
# is a reconstruction - confirm it matches the intended layout.
with gr.Blocks(css=css, js=make_gallery_image_buttons_js) as demo:
    state = gr.State({"selected_block": 0})
    gr.Markdown("# Common Crawl Pipeline Creator")
    gallery = gr.Gallery(
        blocks,
        columns=4,
        rows=2,
        label="Select step to edit",
        object_fit="scale-down",
        show_share_button=False,
        show_download_button=False,
        show_fullscreen_button=False,
        elem_id="pipeline-gallery",
        allow_preview=False,
    )
    # One hidden button per thumbnail; the page JS clicks button i when thumbnail i is selected.
    gallery_image_buttons = [gr.Button(visible=False, elem_classes="block-button") for _ in blocks]

    # Language values offered by every "language(s)" dropdown below.
    language_choices = sorted(v for k, v in vars(Languages).items() if not k.startswith("__"))

    # Step panels, indexed like the gallery thumbnails and the hidden buttons. They must be
    # appended in step order (1..8) so that selecting thumbnail i reveals the panel of step i+1.
    # NOTE(review): assumes the sorted image files follow the step order - confirm.
    blocks_uis = []

    with gr.Column(visible=False) as col:
        blocks_uis.append(col)
        gr.Markdown("## 1. URL Filtering \n\nPerforms filtering based on samples urls.")
        with gr.Group():
            url_filtering_checkbox = gr.Checkbox(True, label="Enable")
            with gr.Accordion("Parameters", open=True) as acc:
                use_integrated_lists_checkbox = gr.Checkbox(True, label="use_integrated_lists", info="use the datatrove integrated lists of banned urls and words")
                with gr.Row():
                    with gr.Column():
                        extra_domain_textbox = gr.Textbox("", label="extra_domains", info="remove if the domain is present in `extra_domains`")
                        extra_domain_textbox.prepare_parameter = prepare_as_list_or_none
                        extra_urls_textbox = gr.Textbox("", label="extra_urls", info="remove if the full url is present on `extra_urls`")
                        extra_urls_textbox.prepare_parameter = prepare_as_list_or_none
                    with gr.Column():
                        banned_words_textbox = gr.Textbox("", label="banned_words", info="remove if any word from `banned_words` is in the url")
                        banned_words_textbox.prepare_parameter = prepare_as_list_or_none
                        banned_subwords_textbox = gr.Textbox("", label="banned_subwords", info="remove if any word from `banned_subwords` is a substring of the url")
                        banned_subwords_textbox.prepare_parameter = prepare_as_list_or_none
                    with gr.Column():
                        soft_banned_words_textbox = gr.Textbox("", label="soft_banned_words", info="remove if there are at least `soft_word_threshold` words from `soft_banned_words` in the url")
                        soft_banned_words_textbox.prepare_parameter = prepare_as_list_or_none
                        soft_word_threshold_slider = gr.Slider(0, 5, value=2, step=1, label="soft_word_threshold", info="remove if there are at least `soft_word_threshold` words from `soft_banned_words` in the url")
            # Hide the parameters when the step is disabled.
            url_filtering_checkbox.change(lambda visible: gr.Accordion(visible=visible), inputs=url_filtering_checkbox, outputs=acc)
        url_filtering_parameters_components = [use_integrated_lists_checkbox, extra_domain_textbox, extra_urls_textbox, banned_words_textbox, banned_subwords_textbox, soft_banned_words_textbox, soft_word_threshold_slider]

    with gr.Column(visible=False) as col:
        blocks_uis.append(col)
        gr.Markdown("## 2. Text Extraction \n\nUses the [Trafilatura](https://trafilatura.readthedocs.io) extractor.")
        with gr.Group():
            text_extraction_checkbox = gr.Checkbox(True, label="Enable")
            with gr.Accordion("Parameters", open=True) as acc:
                with gr.Row():
                    favour_precision_checkbox = gr.Checkbox(True, label="favour_precision", info="prefer less text but correct extraction")
                    timeout_slider = gr.Slider(0.05, 0.5, value=0.1, step=0.05, label="timeout", info="the timeout for extraction, per document, in seconds")
                    deduplicate_checkbox = gr.Checkbox(True, label="deduplicate", info="trafilatura's deduplicate option")
            text_extraction_checkbox.change(lambda visible: gr.Accordion(visible=visible), inputs=text_extraction_checkbox, outputs=acc)
        text_extraction_parameters_components = [favour_precision_checkbox, timeout_slider, deduplicate_checkbox]

    with gr.Column(visible=False) as col:
        blocks_uis.append(col)
        gr.Markdown("## 3. Language Filtering \n\nUses the [fastext](https://fasttext.cc/docs/en/language-identification.html) language identification models.")
        with gr.Group():
            language_filtering_checkbox = gr.Checkbox(True, label="Enable")
            with gr.Accordion("Parameters", open=True) as acc:
                with gr.Row():
                    languages_textbox = gr.Dropdown(language_choices, multiselect=True, label="languages", info="list of languages to keep. empty for all")
                    languages_textbox.prepare_parameter = non_empty_list_or_none
                    language_threshold_slider = gr.Slider(0, 1, value=0.65, step=0.05, label="language_threshold", info="minimum score to accept a document")
            language_filtering_checkbox.change(lambda visible: gr.Accordion(visible=visible), inputs=language_filtering_checkbox, outputs=acc)
        language_filtering_parameters_components = [languages_textbox, language_threshold_slider]

    with gr.Column(visible=False) as col:
        blocks_uis.append(col)
        gr.Markdown("## 4. Gopher Filtering (repetitions) \n\nUses the [Gopher](https://huggingface.co/papers/2112.11446) text repetition filters.")
        with gr.Group():
            gopher_filtering_repetitions_checkbox = gr.Checkbox(True, label="Enable")
            with gr.Accordion("Parameters", open=True) as acc:
                with gr.Group():
                    with gr.Row():
                        language_dropdown1 = gr.Dropdown(language_choices, value=Languages.english, label="language", info="tokenizer language")
                        top_n_grams_textbox = gr.Textbox("(2, 0.2), (3, 0.18), (4, 0.16)", label="top_n_grams")
                        # The textbox content is parsed as a Python literal (a tuple of pairs).
                        top_n_grams_textbox.prepare_parameter = ast.literal_eval
                        dup_n_grams_textbox = gr.Textbox("(5, 0.15), (6, 0.14), (7, 0.13), (8, 0.12), (9, 0.11), (10, 0.10)", label="dup_n_grams")
                        dup_n_grams_textbox.prepare_parameter = ast.literal_eval
                    with gr.Row():
                        dup_line_frac_slider = gr.Slider(0, 1, value=0.3, step=0.05, label="dup_line_frac")
                        dup_para_frac_slider = gr.Slider(0, 1, value=0.3, step=0.05, label="dup_para_frac")
                        dup_line_char_frac_slider = gr.Slider(0, 1, value=0.2, step=0.05, label="dup_line_char_frac")
                        dup_para_char_frac_slider = gr.Slider(0, 1, value=0.2, step=0.05, label="dup_para_char_frac")
            gopher_filtering_repetitions_checkbox.change(lambda visible: gr.Accordion(visible=visible), inputs=gopher_filtering_repetitions_checkbox, outputs=acc)
        gopher_filtering_repetitions_parameters_components = [language_dropdown1, top_n_grams_textbox, dup_n_grams_textbox, dup_line_frac_slider, dup_para_frac_slider, dup_line_char_frac_slider, dup_para_char_frac_slider]

    with gr.Column(visible=False) as col:
        blocks_uis.append(col)
        gr.Markdown("## 5. Gopher Filtering (quality) \n\nUses the [Gopher](https://huggingface.co/papers/2112.11446) text quality filters.")
        with gr.Group():
            gopher_filtering_quality_checkbox = gr.Checkbox(True, label="Enable")
            with gr.Accordion("Parameters", open=True) as acc:
                with gr.Group():
                    with gr.Row():
                        language_dropdown2 = gr.Dropdown(language_choices, value=Languages.english, label="language", info="tokenizer language")
                        min_doc_words_slider = gr.Slider(0, 1000, value=50, step=10, label="min_doc_words")
                        max_doc_words_slider = gr.Slider(0, 200_000, value=100_000, step=10_000, label="max_doc_words")
                    with gr.Row():
                        min_avg_word_length_slider = gr.Slider(0, 20, value=3, step=1, label="min_avg_word_length")
                        max_avg_word_length_slider = gr.Slider(0, 20, value=10, step=1, label="max_avg_word_length")
                    with gr.Row():
                        max_symbol_word_ratio_slider = gr.Slider(0, 1, value=0.1, step=0.05, label="max_symbol_word_ratio")
                        max_bullet_lines_ratio_slider = gr.Slider(0, 1, value=0.9, step=0.05, label="max_bullet_lines_ratio")
                        max_ellipsis_lines_ratio_slider = gr.Slider(0, 1, value=0.3, step=0.05, label="max_ellipsis_lines_ratio")
                        max_non_alpha_words_ratio_slider = gr.Slider(0, 1, value=0.8, step=0.05, label="max_non_alpha_words_ratio")
                    with gr.Row():
                        min_stop_words_slider = gr.Slider(0, 10, value=2, step=1, label="min_stop_words")
                        stop_words_textbox = gr.Textbox("the, be, to, of, and, that, have, with", label="stop_words")
                        stop_words_textbox.prepare_parameter = prepare_as_list_or_none
            gopher_filtering_quality_checkbox.change(lambda visible: gr.Accordion(visible=visible), inputs=gopher_filtering_quality_checkbox, outputs=acc)
        gopher_filtering_quality_parameters_components = [language_dropdown2, min_doc_words_slider, max_doc_words_slider, min_avg_word_length_slider, max_avg_word_length_slider, max_symbol_word_ratio_slider, max_bullet_lines_ratio_slider, max_ellipsis_lines_ratio_slider, max_non_alpha_words_ratio_slider, min_stop_words_slider, stop_words_textbox]

    with gr.Column(visible=False) as col:
        blocks_uis.append(col)
        gr.Markdown("## 6. C4 Filters\n\nUses the [C4](https://huggingface.co/datasets/allenai/c4) text size and content filters.")
        with gr.Group():
            c4_filters_checkbox = gr.Checkbox(True, label="Enable")
            with gr.Accordion(" Parameters", open=True) as acc:
                with gr.Group():
                    with gr.Row():
                        split_paragraph_checkbox = gr.Checkbox(True, label="split_paragraph", info="disable to apply the filters to each sentence instead of to each line")
                    with gr.Row():
                        # Distinct name so this dropdown does not share a variable with the Gopher quality one.
                        language_dropdown3 = gr.Dropdown(language_choices, value=Languages.english, label="language", info="tokenizer language")
                        min_num_sentences_slider = gr.Slider(0, 10, value=5, step=1, label="min_num_sentences", info="remove documents that do not have at least this number of sentences (after line filtering)")
                        min_words_per_line_slider = gr.Slider(0, 10, value=3, step=1, label="min_words_per_line", info="drop lines without this min number of words")
                        max_word_length_slider = gr.Slider(0, 2000, value=1000, step=10, label="max_word_length", info=" drop lines where at least one word has more than this number of characters")
                    with gr.Row():
                        remove_citations_checkbox = gr.Checkbox(True, label="remove_citations", info="remove wikipedia style citations from the text")
                        filter_no_terminal_punct_checkbox = gr.Checkbox(True, label="filter_no_terminal_punct", info="remove lines without terminal punctuation marks")
                        filter_lorem_ipsum_checkbox = gr.Checkbox(True, label="filter_lorem_ipsum", info="drop documents that contain 'lorem ipsum'")
                        filter_javascript_checkbox = gr.Checkbox(True, label="filter_javascript", info="drop lines mentioning 'javascript'")
                        filter_curly_bracket = gr.Checkbox(True, label="filter_curly_bracket", info="drop documents containing {")
                        filter_policy = gr.Checkbox(True, label="filter_policy", info="drop lines containing any of the policy phrases (e.g. 'terms of use', 'use cookies')")
            c4_filters_checkbox.change(lambda visible: gr.Accordion(visible=visible), inputs=c4_filters_checkbox, outputs=acc)
        c4_filters_parameters_components = [split_paragraph_checkbox, language_dropdown3, min_num_sentences_slider, min_words_per_line_slider, max_word_length_slider, remove_citations_checkbox, filter_no_terminal_punct_checkbox, filter_lorem_ipsum_checkbox, filter_javascript_checkbox, filter_curly_bracket, filter_policy]

    with gr.Column(visible=False) as col:
        blocks_uis.append(col)
        gr.Markdown("## 7. Custom Filters \n\nUses the [FineWeb](https://huggingface.co/datasets/HuggingFaceFW/fineweb) custom text filters.")
        with gr.Group():
            custom_filters_checkbox = gr.Checkbox(True, label="Enable")
            with gr.Accordion("Parameters", open=True) as acc:
                with gr.Row():
                    line_punct_thr_slider = gr.Slider(0, 1, value=0.12, step=0.01, label="line_punct_thr")
                    line_punct_exclude_zero = gr.Checkbox(False, label="line_punct_exclude_zero")
                    short_line_thr_slider = gr.Slider(0, 1, value=0.67, step=0.01, label="short_line_thr")
                    short_line_length_slider = gr.Slider(0, 100, value=30, step=1, label="short_line_length")
                    char_duplicates_ratio_slider = gr.Slider(0, 1, value=0.01, step=0.01, label="char_duplicates_ratio")
                    new_line_ratio_slider = gr.Slider(0, 1, value=0.3, step=0.01, label="new_line_ratio")
            custom_filters_checkbox.change(lambda visible: gr.Accordion(visible=visible), inputs=custom_filters_checkbox, outputs=acc)
        custom_filters_parameters_components = [line_punct_thr_slider, line_punct_exclude_zero, short_line_thr_slider, short_line_length_slider, char_duplicates_ratio_slider, new_line_ratio_slider]

    with gr.Column(visible=False) as col:
        blocks_uis.append(col)
        gr.Markdown("## 8. PII Removal \n\nReplaces email addresses and ip addresses in the document text.")
        with gr.Group():
            pii_removal_checkbox = gr.Checkbox(True, label="Enable")
            with gr.Accordion("Parameters", open=True) as acc:
                with gr.Row():
                    remove_emails_checkbox = gr.Checkbox(True, label="remove_emails", info="Replace email addresses")
                    remove_ips_checkbox = gr.Checkbox(True, label="remove_ips", info="Replace IP addresses")
                    only_remove_public_ips_checkbox = gr.Checkbox(True, label="only_remove_public_ips", info="by default we only replace public (and thus PII) IPs")
                with gr.Row():
                    email_replacement_textbox = gr.Textbox("email@example.com, firstname.lastname@example.org", label="email_replacement", info="strings to use as replacement. They will be used in a circular way")
                    email_replacement_textbox.prepare_parameter = prepare_as_list_or_none
                    ip_replacement_textbox = gr.Textbox("22.214.171.124, 126.96.36.199, 188.8.131.52, 184.108.40.206, 220.127.116.11, 18.104.22.168", label="ip_replacement", info="same as email_replacement but for IP addresses")
                    ip_replacement_textbox.prepare_parameter = prepare_as_list_or_none
            pii_removal_checkbox.change(lambda visible: gr.Accordion(visible=visible), inputs=pii_removal_checkbox, outputs=acc)
        pii_removal_parameters_components = [remove_emails_checkbox, remove_ips_checkbox, only_remove_public_ips_checkbox, email_replacement_textbox, ip_replacement_textbox]

    with gr.Row():
        view_pipeline_results_button = gr.Button("Run Pipeline & Stream Results", variant="primary", scale=4)
        # NOTE(review): no event handler is attached to this button in the visible code.
        stop_button = gr.Button("Stop")

    # Step classes and their parameter components, both in pipeline order. Each component's
    # label is used as the keyword-argument name when the step is instantiated.
    steps = [
        URLFilter,
        Trafilatura,
        LanguageFilter,
        GopherRepetitionFilter,
        GopherQualityFilter,
        C4QualityFilter,
        FineWebQualityFilter,
        PIIFormatter,
    ]
    steps_parameters_components = [
        url_filtering_parameters_components,
        text_extraction_parameters_components,
        language_filtering_parameters_components,
        gopher_filtering_repetitions_parameters_components,
        gopher_filtering_quality_parameters_components,
        c4_filters_parameters_components,
        custom_filters_parameters_components,
        pii_removal_parameters_components,
    ]

    with gr.Tab("Output") as output_tab:
        output_dataframe = gr.DataFrame(datatype="markdown")
    with gr.Tab("Excluded") as excluded_tab:
        # One sub-tab and table per filter step; URLFilter is deliberately left out.
        excluded_dataframes: dict[Type, gr.DataFrame] = {}
        excluded_tabs: dict[Type, gr.Tab] = {}
        for step in steps:
            if issubclass(step, BaseFilter) and step is not URLFilter:
                with gr.Tab(step.__name__) as t:
                    excluded_dataframes[step] = gr.DataFrame(datatype="markdown")
                    excluded_tabs[step] = t
    with gr.Tab("Python code") as code_tab:
        python_code_markdown = gr.Markdown(build_code_snippet(steps))

    gr.Markdown("_powered by [datatrove](https://github.com/huggingface/datatrove)_")

    def show_block_ui(i):
        """Show only the i-th step panel and record the selection in `state`."""
        updates = {block_ui: gr.Column(visible=(j == i)) for j, block_ui in enumerate(blocks_uis)}
        updates[state] = {"selected_block": i}
        return updates

    for i, button in enumerate(gallery_image_buttons):
        button.click(partial(show_block_ui, i), outputs=blocks_uis + [state])

    # Handler inputs: the eight "Enable" checkboxes first, then every parameter component,
    # flattened in step order.
    inputs = [
        url_filtering_checkbox,
        text_extraction_checkbox,
        language_filtering_checkbox,
        gopher_filtering_repetitions_checkbox,
        gopher_filtering_quality_checkbox,
        c4_filters_checkbox,
        custom_filters_checkbox,
        pii_removal_checkbox,
    ] + [component for step_components in steps_parameters_components for component in step_components]

    @view_pipeline_results_button.click(inputs=inputs, outputs=[output_tab, output_dataframe, excluded_tab] + list(excluded_dataframes.values()) + list(excluded_tabs.values()))
    def view_pipeline_results(*args):
        """Run the configured pipeline in a background thread and stream results to the UI.

        Args:
            *args: The values of `inputs`: one enable flag per step, followed by the
                parameter values of all steps in order.

        Yields:
            Dicts mapping components to updates: tab labels (with the share of data kept
            or excluded once at least one sample was counted), the output table, and one
            exclusion table and tab label per active filter step.
        """
        enable_steps, steps_parameters = args[:len(steps)], args[len(steps):]

        def prepare_value(parameters_component, value):
            """Apply the component's `prepare_parameter` hook when it defines one."""
            if hasattr(parameters_component, "prepare_parameter"):
                return parameters_component.prepare_parameter(value)
            return value

        # The flat value sequence is consumed step by step through one shared iterator.
        # The components must stay the FIRST argument of zip(): zip stops as soon as they
        # are exhausted, without pulling an extra value from the shared iterator.
        steps_parameters_iter = iter(steps_parameters)
        steps_parameters = [
            {
                parameters_component.label: prepare_value(parameters_component, parameter)
                for parameters_component, parameter in zip(step_parameters_components, steps_parameters_iter)
            }
            for step_parameters_components in steps_parameters_components
        ]
        default_steps_parameters = [
            {
                parameters_component.label: prepare_value(parameters_component, parameters_component.value)
                for parameters_component in step_parameters_components
            }
            for step_parameters_components in steps_parameters_components
        ]

        class ExclusionWriter:
            """In-memory stand-in for a datatrove exclusion writer: collects dropped docs."""

            def __init__(self) -> None:
                self.docs: list[Document] = []

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                return

            def write(self, doc, rank):
                self.docs.append(doc)

        # Instantiate the enabled steps only; filters that have an "Excluded" sub-tab get
        # their own ExclusionWriter.
        steps_to_run = [
            step(**step_parameters, **({"exclusion_writer": ExclusionWriter()} if step in excluded_dataframes else {}))
            for step, step_parameters, enable_step in zip(steps, steps_parameters, enable_steps)
            if enable_step
        ]
        output_docs: list[Document] = []
        num_warc_samples = 0
        # Cap for the uncached path. The cached path scales by 2000 / 1687, which suggests
        # it was built from 2000 WARC records, so the same sample size is used here.
        # NOTE(review): confirm the intended sample size.
        max_warc_samples = 2000

        def increment_num_warc_samples(data, rank, world_size, num_warc_samples_per_doc=1):
            """Pass documents through while counting how many WARC samples they represent."""
            nonlocal num_warc_samples
            for x in data:
                num_warc_samples += num_warc_samples_per_doc
                yield x

        def collect_output(data, rank, world_size):
            """Final pipeline stage: store each surviving document in `output_docs`."""
            return map(output_docs.append, data)

        if steps_parameters[:2] == default_steps_parameters[:2] and all(enable_steps[:2]):
            # URL filtering and text extraction use their defaults: start from the
            # pre-extracted documents and skip those two steps.
            pipeline_executor = LocalPipelineExecutor(
                pipeline=[
                    JsonlReader(data_folder=f"output_text_extraction-2k/base_processing/output/{DUMP_TO_PROCESS}", glob_pattern="*.jsonl.gz"),
                    partial(increment_num_warc_samples, num_warc_samples_per_doc=2000 / 1687),
                ] + steps_to_run[2:] + [collect_output],
                logging_dir="logs",
                skip_completed=False,
            )
        else:
            # Read raw WARC records. The sample counter starts at 0, so it must not be
            # used as the islice limit; cap at a fixed size and count what is read.
            pipeline_executor = LocalPipelineExecutor(
                pipeline=[
                    WarcReader(data_folder="data", glob_pattern="*.warc.gz"),
                    lambda data, rank, world_size: islice(data, max_warc_samples),
                    increment_num_warc_samples,
                ] + steps_to_run + [collect_output],
                logging_dir="logs",
                skip_completed=False,
            )

        def excluded_updates():
            """Build the table and tab-label updates for every active filter with a sub-tab."""
            updates = {}
            for step_to_run in pipeline_executor.pipeline:
                if not (isinstance(step_to_run, BaseFilter) and type(step_to_run) in excluded_dataframes):
                    continue
                step_type = type(step_to_run)
                excluded_docs = step_to_run.exclusion_writer.docs
                updates[excluded_dataframes[step_type]] = pd.DataFrame({"text": [doc.text for doc in excluded_docs]})
                tab_label = step_type.__name__
                # The share can only be computed once at least one sample was counted.
                if num_warc_samples:
                    tab_label += f" (~{len(excluded_docs)/num_warc_samples*100:.03f}% of data)"
                updates[excluded_tabs[step_type]] = gr.Tab(tab_label)
            return updates

        def results_update(still_running):
            """Build one full UI update from the documents collected so far."""
            if num_warc_samples:
                kept_percent = len(output_docs) / num_warc_samples * 100
                updates = {
                    output_tab: gr.Tab(f"Output (~{kept_percent:.03f}% of data)"),
                    excluded_tab: gr.Tab(f"Excluded (~{100 - kept_percent:.03f}% of data)"),
                    output_dataframe: pd.DataFrame({"text": [doc.text for doc in output_docs]}),
                }
            elif still_running:
                updates = {
                    output_tab: gr.Tab("Output (loading...)"),
                    excluded_tab: gr.Tab("Excluded (loading...)"),
                }
            else:
                # The run ended without reading a single sample: report it, do not divide by zero.
                updates = {
                    output_tab: gr.Tab("Output (no data)"),
                    excluded_tab: gr.Tab("Excluded (no data)"),
                    output_dataframe: pd.DataFrame({"text": [doc.text for doc in output_docs]}),
                }
            updates.update(excluded_updates())
            return updates

        thread = Thread(target=pipeline_executor.run)
        thread.start()
        # Poll about once per second and push a snapshot until the pipeline thread ends.
        while thread.is_alive():
            thread.join(timeout=1)
            yield results_update(still_running=True)
        yield results_update(still_running=False)
|
|
|
# Start the Gradio server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|