teven's picture
plotting perplexity vs text length with text hover
a9cc2b2
raw
history blame
1.99 kB
import dash
import plotly.express as px
from dash import dcc, html
from dash.dependencies import Input, Output
from dash.exceptions import PreventUpdate
from datasets import load_dataset
# Create the Dash application object; the layout and the hover callback
# defined later in this file are attached to it.
app = dash.Dash(__name__)
def get_dataset(name, n_items=1000):
    """Load a shuffled sample of ``ola13/small-{name}-dedup`` as a DataFrame.

    Args:
        name: Dataset short name inserted into the hub path
            ``ola13/small-{name}-dedup``.
        n_items: Maximum number of rows to sample from the ``train`` split.
            If the split holds fewer rows, all of them are used.

    Returns:
        A ``(dataset, max_perp)`` tuple: ``dataset`` is a pandas DataFrame
        restricted to the ``text``, ``perplexity`` and ``text_length``
        columns and sorted by ascending ``perplexity``; ``max_perp`` is the
        largest ``perplexity`` value in that frame.
    """
    ola_path = f"ola13/small-{name}-dedup"
    shuffled = load_dataset(ola_path, split="train").shuffle()

    # Clamp the sample size: selecting indices past the end of a split that
    # is smaller than n_items would raise instead of returning what exists.
    sample_size = min(n_items, len(shuffled))
    dataset = shuffled.select(range(sample_size)).to_pandas()

    # Vectorised character count; avoids a Python-level call per row.
    dataset["text_length"] = dataset["text"].str.len()

    # Drop every other column in one step rather than reassigning the frame
    # while iterating over its columns.
    keep = ("text", "perplexity", "text_length")
    dataset = dataset.drop(
        columns=[column for column in dataset.columns if column not in keep]
    )

    dataset = dataset.sort_values("perplexity")
    max_perp = dataset["perplexity"].max()
    return dataset, max_perp
# Dataset to visualise. Names listed by the original author:
# "oscar", "the_pile", "c4", "roots_en".
name = "oscar"
df, max_perplexity = get_dataset(name)

# Scatter plot of perplexity (x) against text length (y). Each point carries
# its document text as custom data so the hover callback can display it.
fig = px.scatter(
    data_frame=df,
    x="perplexity",
    y="text_length",
    custom_data=["text"],
)

# Small markers, click-to-select enabled, and logarithmic scaling on both axes.
fig.update_traces(marker_size=3)
fig.update_layout(clickmode='event+select')
fig.update_yaxes(type="log", title_text="Text Length (log scale)")
fig.update_xaxes(type="log", title_text="Perplexity (log scale)")
# Inline style (camelCase CSS property names) for the box that shows the
# text of the hovered point.
styles = {
    'textbox': {
        'border': 'thin lightgrey solid',
        'overflowX': 'scroll',
        # No trailing ';' here: with the semicolon the value is not a valid
        # white-space keyword, so the rule is ignored and line breaks in the
        # displayed text are not preserved.
        'whiteSpace': 'pre-wrap',
    }
}
# Page layout: the interactive scatter plot on top, and below it a box whose
# content is filled in by the hover callback.
app.layout = html.Div(
    children=[
        dcc.Graph(figure=fig, id="graph_interaction"),
        html.Div(style=styles['textbox'], id='text'),
    ],
)
# Callback: whenever the hovered point of the graph changes, copy that
# point's text into the 'text' box.
@app.callback(
    Output('text', 'children'),
    Input('graph_interaction', 'hoverData'),
)
def open_url(hoverData):
    """Return the text attached to the first hovered point.

    Args:
        hoverData: Value of the graph's ``hoverData`` property; falsy when
            nothing is hovered.

    Returns:
        The first custom-data entry of the first hovered point.

    Raises:
        PreventUpdate: If ``hoverData`` is falsy, so the box keeps its
            current content.
    """
    if not hoverData:
        raise PreventUpdate
    first_point = hoverData["points"][0]
    return first_point["customdata"][0]
# Script entry point: serve the app on all network interfaces at port 7860.
# NOTE(review): debug=True is combined with a public bind address (0.0.0.0);
# confirm that exposing debug mode is intended for this deployment.
if __name__ == "__main__":
    app.run_server(host="0.0.0.0", port=7860, debug=True)