Spaces:
Runtime error
Runtime error
import dash | |
import plotly.express as px | |
from dash import dcc, html | |
from dash.dependencies import Input, Output | |
from dash.exceptions import PreventUpdate | |
from datasets import load_dataset | |
# Dash application instance; the page layout below is assigned to it.
app = dash.Dash(name=__name__)
def get_dataset(name, n_items=1000): | |
ola_path = f"ola13/small-{name}-dedup" | |
dataset = load_dataset(ola_path, split="train").shuffle().select(range(n_items)).to_pandas() | |
dataset["text_length"] = dataset.apply(lambda doc: len(doc["text"]), axis=1) | |
for column in dataset.columns: | |
if column not in ["text", "perplexity", "text_length"]: | |
dataset = dataset.drop(column, axis=1) | |
dataset = dataset.sort_values("perplexity") | |
max_perp = dataset["perplexity"].max() | |
return dataset, max_perp | |
# Which ola13/small-*-dedup dataset to explore; other available choices are
# "oscar", "the_pile" and "roots_en".
name = "c4"
df, max_perplexity = get_dataset(name=name)
# Perplexity vs. text length for every sampled document, both on log axes.
# The raw text rides along as custom data so hover events can expose it.
# Each update_* call returns the figure itself, so the calls are chained.
fig = (
    px.scatter(df, x="perplexity", y="text_length", custom_data=["text"])
    .update_layout(clickmode="event+select")
    .update_traces(marker_size=3)
    .update_xaxes(title_text="Perplexity (log scale)", type="log")
    .update_yaxes(title_text="Text Length (log scale)", type="log")
)
# Inline CSS for the components in the layout, keyed by component role.
# Values are React inline-style values: they must NOT carry a trailing ";"
# (e.g. "pre-wrap;" is rejected by the browser and silently ignored).
styles = {
    "textbox": {
        "border": "thin lightgrey solid",
        "overflowX": "scroll",
        # Preserve newlines/whitespace of the hovered document text.
        "whiteSpace": "pre-wrap",
    }
}
# Page layout: the interactive scatter plot, followed by a box whose
# component id is "text".
app.layout = html.Div(
    children=[
        dcc.Graph(id="graph_interaction", figure=fig),
        html.Div(id="text", style=styles["textbox"]),
    ]
)
# Without the decorator Dash never registers this function, so the "text"
# box would never update when hovering over the graph.
@app.callback(Output("text", "children"), Input("graph_interaction", "hoverData"))
def open_url(hoverData):
    """Show the text of the document currently hovered in the scatter plot.

    Args:
        hoverData: Dash hover payload for ``graph_interaction``; ``None``
            before the first hover.

    Returns:
        The first ``customdata`` entry of the first hovered point, which
        is the document text attached via ``custom_data=["text"]``.

    Raises:
        PreventUpdate: when there is no hover payload or it has no points,
            leaving the text box unchanged.
    """
    if not hoverData or not hoverData.get("points"):
        raise PreventUpdate
    return hoverData["points"][0]["customdata"][0]
if __name__ == "__main__":
    # Dash 3 removed ``run_server`` in favour of ``run``; fall back to
    # ``run_server`` only on old Dash releases that lack ``run``.
    serve = getattr(app, "run", None) or app.run_server
    # NOTE(review): debug=True on host 0.0.0.0 exposes Dash dev tools and the
    # Flask debugger publicly — consider disabling it for deployment.
    serve(port=7860, host="0.0.0.0", debug=True)