from functools import partial
import json
import re
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
from typing import List
from datatrove.io import get_datafolder, _get_true_fs
from datatrove.utils.stats import MetricStatsDict
import gradio as gr
import tenacity

from src.logic.graph_settings import Grouping

def find_folders(base_folder: str, path: str) -> List[str]:
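    """Return the immediate sub-directories of `path` inside `base_folder` (non-recursive)."""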
    base_folder_df = get_datafolder(base_folder)
    if not base_folder_df.exists(path):
        return []

    from huggingface_hub import HfFileSystem
    extra_options = {}
    if isinstance(_get_true_fs(base_folder_df.fs), HfFileSystem):
        extra_options["expand_info"] = False  # speed up

    return [
        folder
        for folder, info in base_folder_df.find(path, maxdepth=1, withdirs=True, detail=True, **extra_options).items()
        if info["type"] == "directory" and folder.rstrip("/") != path.rstrip("/")
    ]

def fetch_datasets(base_folder: str, progress=gr.Progress()):
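    """Return (sorted top-level dataset folders, None); raises ValueError when no datasets are found."""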
    datasets = sorted(progress.tqdm(find_folders(base_folder, "")))
    if len(datasets) == 0:
        raise ValueError("No datasets found")
    return datasets, None

def fetch_groups(base_folder: str, datasets: List[str], old_groups: str, type: str = "intersection"):
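    """Update the groups dropdown with the groups available across `datasets` (set intersection or
    union), keeping the previous selection when it is still available."""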
    if not datasets:
        return gr.update(choices=[], value=None)

    with ThreadPoolExecutor() as executor:
        groups = list(executor.map(lambda run: [Path(x).name for x in find_folders(base_folder, run)], datasets))
    if len(groups) == 0:
        return gr.update(choices=[], value=None)

    if type == "intersection":
        new_choices = set.intersection(*(set(g) for g in groups))
    else:
        new_choices = set.union(*(set(g) for g in groups))
    value = None
    if old_groups:
        value = list(set.intersection(new_choices, {old_groups}))
        value = value[0] if value else None

    if not value and len(new_choices) == 1:
        value = list(new_choices)[0]

    return gr.Dropdown(choices=sorted(list(new_choices)), value=value)

def fetch_metrics(base_folder: str, datasets: List[str], group: str, old_metrics: str, type: str = "intersection"):
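    """Update the metrics dropdown with the metrics available under `group` across `datasets`
    (set intersection or union), keeping the previous selection when it is still available."""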
    if not group:
        return gr.update(choices=[], value=None)

    with ThreadPoolExecutor() as executor:
        metrics = list(
            executor.map(lambda run: [Path(x).name for x in find_folders(base_folder, f"{run}/{group}")], datasets))
    if len(metrics) == 0:
        return gr.update(choices=[], value=None)

    if type == "intersection":
        new_possible_choices = set.intersection(*(set(s) for s in metrics))
    else:
        new_possible_choices = set.union(*(set(s) for s in metrics))
    value = None
    if old_metrics:
        value = list(set.intersection(new_possible_choices, {old_metrics}))
        value = value[0] if value else None

    if not value and len(new_possible_choices) == 1:
        value = list(new_possible_choices)[0]

    return gr.Dropdown(choices=sorted(list(new_possible_choices)), value=value)

def reverse_search(base_folder: str, possible_datasets: List[str], grouping: str, metric_name: str) -> str:
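    """Return, newline-separated, the datasets from `possible_datasets` that contain `metric_name` under `grouping`."""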
    with ThreadPoolExecutor() as executor:
        found_datasets = list(executor.map(
            lambda dataset: dataset if metric_exists(base_folder, dataset, metric_name, grouping) else None,
            possible_datasets))
    found_datasets = [dataset for dataset in found_datasets if dataset is not None]
    return "\n".join(found_datasets)

def reverse_search_add(datasets: List[str], reverse_search_results: str) -> List[str]:
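    """Merge the reverse-search results into the currently selected datasets (deduplicated)."""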
    datasets = datasets or []
    return list(set(datasets + reverse_search_results.strip().split("\n")))

def metric_exists(base_folder: str, path: str, metric_name: str, group_by: str) -> bool:
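    """Check whether `metric.json` exists for the given dataset path, grouping and metric name."""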
    base_folder = get_datafolder(base_folder)
    return base_folder.exists(f"{path}/{group_by}/{metric_name}/metric.json")

@tenacity.retry(stop=tenacity.stop_after_attempt(5))
def load_metrics(base_folder: str, path: str, metric_name: str, group_by: str) -> MetricStatsDict:
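    """Load and parse `metric.json` into a MetricStatsDict, retrying up to 5 times on failure."""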
    base_folder = get_datafolder(base_folder)
    with base_folder.open(f"{path}/{group_by}/{metric_name}/metric.json") as f:
        json_metric = json.load(f)
        return MetricStatsDict.from_dict(json_metric)

def load_data(dataset_path: str, base_folder: str, grouping: str, metric_name: str) -> MetricStatsDict:
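    """Wrapper around `load_metrics` with an argument order suited to `functools.partial` in `fetch_graph_data`."""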
    return load_metrics(base_folder, dataset_path, metric_name, grouping)


def fetch_graph_data(
        base_folder: str,
        datasets: List[str],
        metric_name: str,
        grouping: Grouping,
        progress=gr.Progress(),
):
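    """Load `metric_name` for every selected dataset in parallel; returns ({dataset: MetricStatsDict}, None)."""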
    if not datasets or not metric_name or not grouping:
        return None, None

    with ThreadPoolExecutor() as pool:
        data = list(
            progress.tqdm(
                pool.map(
                    partial(load_data, base_folder=base_folder, metric_name=metric_name, grouping=grouping),
                    datasets,
                ),
                total=len(datasets),
                desc="Loading data...",
            )
        )

    data = {path: result for path, result in zip(datasets, data)}
    return data, None

def update_datasets_with_regex(regex: str, selected_runs: List[str], all_runs: List[str]):
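    """Add every run matching `regex` to the current selection; returns [] for an empty regex and the
    unchanged selection when nothing matches."""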
    if not regex:
        return []
    new_dsts = {run for run in all_runs if re.search(regex, run)}
    if not new_dsts:
        return selected_runs
    dst_union = new_dsts.union(selected_runs or [])
    return sorted(list(dst_union))
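

# Example wiring outside the Gradio UI: a minimal sketch, not part of the app.
# "stats-data", "length" and "summary" are placeholder values; the sketch only assumes a
# folder laid out as {dataset}/{grouping}/{metric}/metric.json (the layout the functions
# above read) and that the gr.Progress() defaults are usable outside a running event.
if __name__ == "__main__":
    example_base = "stats-data"
    example_datasets, _ = fetch_datasets(example_base)
    example_data, _ = fetch_graph_data(example_base, example_datasets, metric_name="length", grouping="summary")
    print(sorted(example_data))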