File size: 4,766 Bytes
40e38d3 75448af 40e38d3 276d919 40e38d3 75448af 40e38d3 75448af 40e38d3 75448af 40e38d3 75448af 40e38d3 75448af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
from functools import partial
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import gradio as gr
from typing import Dict, List
from src.logic.data_processing import PARTITION_OPTIONS, prepare_for_non_grouped_plotting, prepare_for_group_plotting
from src.logic.graph_settings import Grouping
from src.logic.utils import set_alpha
from datatrove.utils.stats import MetricStatsDict
def plot_scatter(
    data: Dict[str, MetricStatsDict],
    metric_name: str,
    log_scale_x: bool,
    log_scale_y: bool,
    normalization: bool,
    rounding: int,
    cumsum: bool,
    perc: bool,
    progress: gr.Progress,
):
    """Build a figure with one line trace per run for ``metric_name``.

    Runs are plotted in sorted order of their names. Each run's histogram is
    passed through ``prepare_for_non_grouped_plotting`` and its keys, sorted
    ascending, become the x values.

    Args:
        data: Mapping from run name to that run's histogram.
        metric_name: Used for the figure title and the x-axis title.
        log_scale_x: Request a logarithmic x axis.
        log_scale_y: Request a logarithmic y axis.
        normalization: Forwarded to ``prepare_for_non_grouped_plotting``; also
            selects the y-axis title ("Frequency" if true, else "Total").
        rounding: Forwarded to ``prepare_for_non_grouped_plotting``.
        cumsum: Replace the y values by their cumulative sum.
        perc: Multiply the y values by 100.
        progress: Gradio progress tracker; its ``tqdm`` wraps the run loop.

    Returns:
        The populated ``go.Figure``.

    Note:
        A log axis is only applied when the last processed trace has more than
        one point; with empty ``data`` no log axis is applied.
    """
    fig = go.Figure()
    palette = px.colors.qualitative.Plotly

    # Bound before the loop so the axis-type checks below still work when
    # `data` is empty (previously this raised UnboundLocalError).
    x = []
    y = []

    data = {name: histogram for name, histogram in sorted(data.items())}
    for i, (name, histogram) in enumerate(progress.tqdm(data.items(), total=len(data), desc="Plotting...")):
        histogram_prepared = prepare_for_non_grouped_plotting(histogram, normalization, rounding)
        x = sorted(histogram_prepared.keys())
        y = [histogram_prepared[k] for k in x]
        if cumsum:
            y = np.cumsum(y).tolist()
        if perc:
            y = (np.array(y) * 100).tolist()

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=name,
                # Cycle through the palette so any number of runs gets a color.
                marker=dict(color=set_alpha(palette[i % len(palette)], 0.5)),
            )
        )

    yaxis_title = "Frequency" if normalization else "Total"
    fig.update_layout(
        title=f"Line Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title=yaxis_title,
        xaxis_type="log" if log_scale_x and len(x) > 1 else None,
        yaxis_type="log" if log_scale_y and len(y) > 1 else None,
        width=1200,
        height=600,
        showlegend=True,
    )
    return fig
def plot_bars(
    data: Dict[str, MetricStatsDict],
    metric_name: str,
    top_k: int,
    direction: PARTITION_OPTIONS,
    regex: str | None,
    rounding: int,
    log_scale_x: bool,
    log_scale_y: bool,
    show_stds: bool,
    progress: gr.Progress,
):
    """Build a figure with one bar trace per run for ``metric_name``.

    Each run's data is reduced by ``prepare_for_group_plotting`` (which
    receives ``top_k``, ``direction``, ``regex`` and ``rounding``) to x values,
    y values and an error array. Error bars are always attached; ``show_stds``
    only toggles their visibility.

    Args:
        data: Mapping from run name to that run's stats.
        metric_name: Used for the figure title and the x-axis title.
        top_k: Forwarded to ``prepare_for_group_plotting``.
        direction: Forwarded to ``prepare_for_group_plotting``.
        regex: Forwarded to ``prepare_for_group_plotting``.
        rounding: Forwarded to ``prepare_for_group_plotting``.
        log_scale_x: Request a logarithmic x axis.
        log_scale_y: Request a logarithmic y axis.
        show_stds: Whether the error bars are visible.
        progress: Gradio progress tracker; its ``tqdm`` wraps the run loop.

    Returns:
        The populated ``go.Figure``.

    Note:
        A log axis is only applied when the last processed run produced more
        than one value on that axis.
    """
    figure = go.Figure()
    palette = px.colors.qualitative.Plotly

    # Kept outside the loop: the layout below inspects the last run's values.
    labels = []
    means = []

    runs = progress.tqdm(data.items(), total=len(data), desc="Plotting...")
    for idx, (run_name, stats) in enumerate(runs):
        labels, means, deviations = prepare_for_group_plotting(stats, top_k, direction, regex, rounding)
        bar_color = set_alpha(palette[idx % len(palette)], 0.5)
        bar = go.Bar(
            x=labels,
            y=means,
            name=f"{run_name} Mean",
            marker={"color": bar_color},
            error_y={"type": "data", "array": deviations, "visible": show_stds},
        )
        figure.add_trace(bar)

    if log_scale_x and len(labels) > 1:
        x_axis_kind = "log"
    else:
        x_axis_kind = None

    if log_scale_y and len(means) > 1:
        y_axis_kind = "log"
    else:
        y_axis_kind = None

    figure.update_layout(
        title=f"Bar Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title="Avg. value",
        xaxis_type=x_axis_kind,
        yaxis_type=y_axis_kind,
        autosize=True,
        width=1200,
        height=600,
        showlegend=True,
    )
    return figure
# Add any other necessary functions
def plot_data(
    metric_data: Dict[str, MetricStatsDict],
    metric_name: str,
    normalize: bool,
    rounding: int,
    grouping: Grouping,
    top_n: int,
    direction: PARTITION_OPTIONS,
    group_regex: str,
    log_scale_x: bool,
    log_scale_y: bool,
    cdf: bool,
    perc: bool,
    show_stds: bool,
) -> tuple[go.Figure, gr.Row, str]:
    """Dispatch to the line or bar plot depending on ``grouping``.

    When ``grouping`` equals ``"histogram"`` the data is drawn with
    ``plot_scatter`` and a Markdown min/max table is produced with
    ``generate_min_max_hist_data``. For any other grouping the data is drawn
    with ``plot_bars`` and the table is an empty string.

    Args:
        metric_data: Mapping from run name to that run's stats.
        metric_name: Name shown in the plot title and x-axis title.
        normalize: Histogram path only; passed as ``normalization``.
        rounding: Passed to either plotting function.
        grouping: Selects the plot type.
        top_n: Bar path only; passed as ``top_k``.
        direction: Bar path only.
        group_regex: Bar path only; passed as ``regex``.
        log_scale_x: Request a logarithmic x axis.
        log_scale_y: Request a logarithmic y axis.
        cdf: Histogram path only; passed as ``cumsum``.
        perc: Histogram path only.
        show_stds: Bar path only.

    Returns:
        A tuple of the figure, a Gradio update that sets ``visible=True`` on
        the bound output component, and the Markdown table (or ``""``).
    """
    if grouping == "histogram":
        fig = plot_scatter(
            metric_data,
            metric_name,
            log_scale_x,
            log_scale_y,
            normalize,
            rounding,
            cdf,
            perc,
            gr.Progress(),
        )
        min_max_hist_data = generate_min_max_hist_data(metric_data)
        # `gr.update` is the version-independent way to emit a property update;
        # the per-component `gr.Row.update` classmethod no longer exists in
        # newer Gradio releases.
        return fig, gr.update(visible=True), min_max_hist_data

    fig = plot_bars(
        metric_data,
        metric_name,
        top_n,
        direction,
        group_regex,
        rounding,
        log_scale_x,
        log_scale_y,
        show_stds,
        gr.Progress(),
    )
    return fig, gr.update(visible=True), ""
def generate_min_max_hist_data(data: Dict[str, MetricStatsDict]) -> str:
    """Render a Markdown table with the smallest and largest key of each run.

    Every key of a run's histogram is converted with ``float`` and the minimum
    and maximum are printed with four decimal places. A run whose histogram
    has no keys is listed with ``N/A`` in both columns instead of raising.

    Args:
        data: Mapping from run name to that run's histogram; only ``.keys()``
            of each histogram is used.

    Returns:
        The table header followed by one row per run, rows joined by newlines.
        With no runs, only the header is returned.

    Raises:
        ValueError: If a key cannot be converted to ``float``.
    """
    header = "| Run | Min | Max |\n|-----|-----|-----|\n"

    rows = []
    for run, histogram in data.items():
        # Convert the keys once and reuse them for both extremes.
        keys = [float(key) for key in histogram.keys()]
        if keys:
            rows.append(f"| {run} | {min(keys):.4f} | {max(keys):.4f} |")
        else:
            # min()/max() raise ValueError on an empty sequence.
            rows.append(f"| {run} | N/A | N/A |")

    return header + "\n".join(rows)