Spaces:
Runtime error
Runtime error
import logging | |
from functools import partial | |
from typing import Callable, Optional | |
import pandas as pd | |
import streamlit as st | |
from bokeh.plotting import Figure | |
from embedding_lenses.data import uploaded_file_to_dataframe | |
from embedding_lenses.dimensionality_reduction import (get_tsne_embeddings, | |
get_umap_embeddings) | |
from embedding_lenses.embedding import embed_text, load_model | |
from embedding_lenses.utils import encode_labels | |
from embedding_lenses.visualization import draw_interactive_scatter_plot | |
from sentence_transformers import SentenceTransformer | |
from data import hub_dataset_to_dataframe | |
from perplexity import KenlmModel | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model names listed in the "Sentence embedding model" selector; the chosen
# one is handed to load_model().
EMBEDDING_MODELS = [
    "distiluse-base-multilingual-cased-v1",
    "all-mpnet-base-v2",
    "flax-sentence-embeddings/all_datasets_v3_mpnet-base",
]

# Choices for the "Dimensionality Reduction algorithm" selector. The script
# compares the selection against the literal "UMAP"; anything else uses t-SNE.
DIMENSIONALITY_REDUCTION_ALGORITHMS = ["UMAP", "t-SNE"]

# Language codes listed in the "Language" selector; the chosen one is passed
# to KenlmModel.from_pretrained(). Order matters: the script picks its default
# entry by index, so do not reorder without updating that index.
LANGUAGES = (
    "af ar az be bg bn ca cs da de el en es et fa fi fr gu he hi hr hu hy id "
    "is it ja ka kk km kn ko lt lv mk ml mn mr my ne nl no pl pt ro ru uk zh"
).split()

# Shared random_state for row sampling and the dimensionality reducers.
SEED = 0
def generate_plot(
    df: pd.DataFrame,
    text_column: str,
    label_column: str,
    sample: Optional[int],
    dimensionality_reduction_function: Callable,
    model: SentenceTransformer,
) -> Figure:
    """Embed a text column, project it to 2D and build an interactive scatter plot.

    Args:
        df: Source data. It is not modified; if ``label_column`` is missing,
            it is added to a copy only.
        text_column: Column holding the texts to embed. Must exist in ``df``.
        label_column: Column used to colour the points. If absent, every row
            gets the constant label ``0``.
        sample: If truthy, at most this many rows are drawn at random (with
            ``random_state=SEED``). Sampling happens after rows with a missing
            text or label have been dropped.
        dimensionality_reduction_function: Callable applied to the embeddings;
            columns 0 and 1 of its result are used as the x and y coordinates.
        model: Sentence embedding model passed to ``embed_text``.

    Returns:
        The figure produced by ``draw_interactive_scatter_plot``.

    Raises:
        ValueError: If ``text_column`` is not a column of ``df``.
    """
    if text_column not in df.columns:
        raise ValueError(f"The specified column name doesn't exist. Columns available: {df.columns.values}")
    if label_column not in df.columns:
        # assign() returns a new frame, so the caller's DataFrame is left untouched
        # (the previous `df[label_column] = 0` mutated the argument in place).
        df = df.assign(**{label_column: 0})
    df = df.dropna(subset=[text_column, label_column])
    if sample:
        # Cap at the available row count so sampling never fails on small frames.
        df = df.sample(min(sample, df.shape[0]), random_state=SEED)
    with st.spinner(text="Embedding text..."):
        embeddings = embed_text(df[text_column].values.tolist(), model)
    logger.info("Encoding labels")
    encoded_labels = encode_labels(df[label_column])
    with st.spinner("Reducing dimensionality..."):
        embeddings_2d = dimensionality_reduction_function(embeddings)
    logger.info("Generating figure")
    plot = draw_interactive_scatter_plot(
        df[text_column].values,
        embeddings_2d[:, 0],
        embeddings_2d[:, 1],
        encoded_labels.values,
        df[label_column].values,
        text_column,
        label_column,
    )
    return plot
# --- Page header and data-source inputs -------------------------------------
st.title("Perplexity Lenses")
st.write("Visualize text embeddings in 2D using colors to represent perplexity values.")
uploaded_file = st.file_uploader("Choose an csv/tsv file...", type=["csv", "tsv"])
st.write("Alternatively, select a dataset from the [hub](https://huggingface.co/datasets)")
col1, col2, col3 = st.columns(3)
with col1:
    hub_dataset = st.text_input("Dataset name", "mc4")
with col2:
    hub_dataset_config = st.text_input("Dataset configuration", "es")
with col3:
    hub_dataset_split = st.text_input("Dataset split", "train")
text_column = st.text_input("Text column name", "text")

# --- Model / algorithm selection --------------------------------------------
# Default index 12 selects "es", matching the default dataset configuration above.
language = st.selectbox("Language", LANGUAGES, 12)
sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
dimensionality_reduction = st.selectbox("Dimensionality Reduction algorithm", DIMENSIONALITY_REDUCTION_ALGORITHMS, 0)
model_name = st.selectbox("Sentence embedding model", EMBEDDING_MODELS, 0)
with st.spinner(text="Loading embedding model..."):
    model = load_model(model_name)
dimensionality_reduction_function = (
    partial(get_umap_embeddings, random_state=SEED)
    if dimensionality_reduction == "UMAP"
    else partial(get_tsne_embeddings, random_state=SEED)
)
with st.spinner(text="Loading KenLM model..."):
    kenlm_model = KenlmModel.from_pretrained(language)

# --- Load data, score perplexity, plot --------------------------------------
if uploaded_file or hub_dataset:
    with st.spinner("Loading dataset..."):
        if uploaded_file:
            df = uploaded_file_to_dataframe(uploaded_file)
            # Validate before indexing so a wrong column name gives the same clear
            # error as generate_plot instead of a bare KeyError.
            if text_column not in df.columns:
                raise ValueError(f"The specified column name doesn't exist. Columns available: {df.columns.values}")
            # Perplexity comes from the KenLM model, not the sentence-embedding
            # model. Series.map hands each cell (already the text itself) to the
            # callable; na_action="ignore" keeps missing texts as NaN, and
            # generate_plot drops those rows via dropna.
            df["perplexity"] = df[text_column].map(kenlm_model.get_perplexity, na_action="ignore")
        else:
            # NOTE(review): kenlm_model is passed in, so this presumably adds the
            # "perplexity" column itself — confirm in data.py.
            df = hub_dataset_to_dataframe(
                hub_dataset, hub_dataset_config, hub_dataset_split, sample, text_column, kenlm_model, seed=SEED
            )
    plot = generate_plot(df, text_column, "perplexity", sample, dimensionality_reduction_function, model)
    logger.info("Displaying plot")
    st.bokeh_chart(plot)
    logger.info("Done")