File size: 5,043 Bytes
1f30dbc
 
 
 
 
 
 
 
a86046b
1f30dbc
 
 
 
a86046b
 
 
1f30dbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a86046b
1f30dbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a86046b
 
 
 
 
 
 
 
 
 
 
 
1f30dbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a86046b
 
abf62cb
1f30dbc
a86046b
 
 
 
 
 
1f30dbc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import logging
from functools import partial
from typing import Callable, Optional

import pandas as pd
import streamlit as st
from bokeh.plotting import Figure
from embedding_lenses.data import uploaded_file_to_dataframe
from embedding_lenses.dimensionality_reduction import get_tsne_embeddings, get_umap_embeddings
from embedding_lenses.embedding import embed_text, load_model
from embedding_lenses.utils import encode_labels
from sentence_transformers import SentenceTransformer

from perplexity_lenses.data import documents_df_to_sentences_df, hub_dataset_to_dataframe
from perplexity_lenses.perplexity import KenlmModel
from perplexity_lenses.visualization import draw_interactive_scatter_plot

# Configure the root logger at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Sentence-embedding model names offered in the UI.
EMBEDDING_MODELS = [
    "distiluse-base-multilingual-cased-v1",
    "all-mpnet-base-v2",
    "flax-sentence-embeddings/all_datasets_v3_mpnet-base",
]

# Algorithms available for projecting the embeddings down to 2D.
DIMENSIONALITY_REDUCTION_ALGORITHMS = ["UMAP", "t-SNE"]

# Two-letter language codes offered in the UI; the chosen one is used to load
# the KenLM model. Keep the order stable: code elsewhere may pick a default
# entry by position.
LANGUAGES = [
    "af", "ar", "az", "be", "bg", "bn", "ca", "cs",
    "da", "de", "el", "en", "es", "et", "fa", "fi",
    "fr", "gu", "he", "hi", "hr", "hu", "hy", "id",
    "is", "it", "ja", "ka", "kk", "km", "kn", "ko",
    "lt", "lv", "mk", "ml", "mn", "mr", "my", "ne",
    "nl", "no", "pl", "pt", "ro", "ru", "uk", "zh",
]

# Granularity at which the text is embedded and scored.
DOCUMENT_TYPES = ["Whole document", "Sentence"]

# Shared random seed for sampling and dimensionality reduction.
SEED = 0


def generate_plot(
    df: pd.DataFrame,
    text_column: str,
    label_column: str,
    sample: Optional[int],
    dimensionality_reduction_function: Callable,
    model: SentenceTransformer,
) -> Figure:
    """Embed the texts in ``df`` and draw them as an interactive 2D scatter plot.

    Rows with a missing text or label are dropped, the remaining rows are
    optionally down-sampled (seeded with ``SEED``), embedded with ``model``,
    projected to 2D and handed to ``draw_interactive_scatter_plot``.

    Args:
        df: Source data. It is not modified in place.
        text_column: Name of the column holding the texts to embed.
        label_column: Name of the column used to label the points. If the
            column does not exist, every row gets the label ``0``.
        sample: Maximum number of rows to keep. A falsy value (``None`` or
            ``0``) keeps every row.
        dimensionality_reduction_function: Callable applied to the embeddings;
            its result is indexed as a 2D array whose columns 0 and 1 are the
            plot coordinates.
        model: Sentence-embedding model passed to ``embed_text``.

    Returns:
        The figure produced by ``draw_interactive_scatter_plot``.

    Raises:
        ValueError: If ``text_column`` is not a column of ``df``.
    """
    if text_column not in df.columns:
        raise ValueError(f"The specified column name doesn't exist. Columns available: {df.columns.values}")
    if label_column not in df.columns:
        # assign() returns a new frame, so the caller's DataFrame does not
        # silently gain a column as a side effect of plotting.
        df = df.assign(**{label_column: 0})
    df = df.dropna(subset=[text_column, label_column])
    if sample:
        # Never request more rows than are available.
        df = df.sample(min(sample, df.shape[0]), random_state=SEED)
    with st.spinner(text="Embedding text..."):
        embeddings = embed_text(df[text_column].values.tolist(), model)
    logger.info("Encoding labels")
    encoded_labels = encode_labels(df[label_column])
    with st.spinner("Reducing dimensionality..."):
        embeddings_2d = dimensionality_reduction_function(embeddings)
    logger.info("Generating figure")
    plot = draw_interactive_scatter_plot(
        df[text_column].values,
        embeddings_2d[:, 0],
        embeddings_2d[:, 1],
        encoded_labels.values,
        df[label_column].values,
        text_column,
        label_column,
    )
    return plot


# --- Page header and data source selection ---------------------------------
st.title("Perplexity Lenses")
st.write("Visualize text embeddings in 2D using colors to represent perplexity values.")
uploaded_file = st.file_uploader("Choose an csv/tsv file...", type=["csv", "tsv"])
st.write("Alternatively, select a dataset from the [hub](https://huggingface.co/datasets)")
col1, col2, col3 = st.columns(3)
with col1:
    hub_dataset = st.text_input("Dataset name", "mc4")
with col2:
    hub_dataset_config = st.text_input("Dataset configuration", "es")
with col3:
    hub_dataset_split = st.text_input("Dataset split", "train")

# --- Text column and language ----------------------------------------------
col4, col5 = st.columns(2)
with col4:
    text_column = st.text_input("Text field name", "text")
with col5:
    # Look the default up by value so it survives edits to LANGUAGES.
    language = st.selectbox("Language", LANGUAGES, LANGUAGES.index("es"))

# --- Document granularity and size limit -----------------------------------
col6, col7 = st.columns(2)
with col6:
    doc_type = st.selectbox("Document type", DOCUMENT_TYPES, DOCUMENT_TYPES.index("Sentence"))
with col7:
    sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)

# --- Algorithm and model selection -----------------------------------------
dimensionality_reduction = st.selectbox("Dimensionality Reduction algorithm", DIMENSIONALITY_REDUCTION_ALGORITHMS, 0)
model_name = st.selectbox("Sentence embedding model", EMBEDDING_MODELS, 0)

with st.spinner(text="Loading embedding model..."):
    model = load_model(model_name)
# Both reducers are seeded with SEED.
dimensionality_reduction_function = (
    partial(get_umap_embeddings, random_state=SEED) if dimensionality_reduction == "UMAP" else partial(get_tsne_embeddings, random_state=SEED)
)

with st.spinner(text="Loading KenLM model..."):
    kenlm_model = KenlmModel.from_pretrained(language)

# --- Load data, score perplexity and plot ----------------------------------
# An uploaded file takes precedence over the hub dataset.
if uploaded_file or hub_dataset:
    with st.spinner("Loading dataset..."):
        if uploaded_file:
            df = uploaded_file_to_dataframe(uploaded_file)
            if doc_type == "Sentence":
                df = documents_df_to_sentences_df(df, text_column, sample, seed=SEED)
            elif df.shape[0] > sample:
                # Whole-document uploads previously ignored the "Maximum number
                # of documents to use" setting; apply it before the perplexity
                # scoring and embedding steps.
                df = df.sample(sample, random_state=SEED)
            df["perplexity"] = df[text_column].map(kenlm_model.get_perplexity)
        else:
            df = hub_dataset_to_dataframe(hub_dataset, hub_dataset_config, hub_dataset_split, sample, text_column, kenlm_model, seed=SEED, doc_type=doc_type)

    # Round perplexity to whole numbers before it is used as the label.
    df["perplexity"] = df["perplexity"].round().astype(int)
    logger.info("Perplexity range: %s - %s", df["perplexity"].min(), df["perplexity"].max())
    # The data has already been limited above, so no further sampling here.
    plot = generate_plot(df, text_column, "perplexity", None, dimensionality_reduction_function, model)
    logger.info("Displaying plot")
    st.bokeh_chart(plot)
    logger.info("Done")