Spaces:
Runtime error
Runtime error
File size: 3,686 Bytes
86e673e 3c30fa3 86e673e 3c30fa3 86e673e 3c30fa3 86e673e 3c30fa3 86e673e ab846df 86e673e 3c30fa3 86e673e 3c30fa3 86e673e 7b62017 86e673e 3c30fa3 86e673e 3c30fa3 7b62017 3c30fa3 86e673e 3c30fa3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import logging
import time
from typing import Callable, Optional, Tuple, Union
import pandas as pd
import streamlit as st
from bokeh.palettes import Turbo256
from bokeh.plotting import Figure
from embedding_lenses.embedding import embed_text
from embedding_lenses.utils import encode_labels
from embedding_lenses.visualization import draw_interactive_scatter_plot
from sentence_transformers import SentenceTransformer
from perplexity_lenses import REGISTRY_DATASET
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Selectable embedding model names.
# NOTE(review): presumably sentence-transformers model identifiers -- confirm.
EMBEDDING_MODELS = [
    "distiluse-base-multilingual-cased-v1",
    "distiluse-base-multilingual-cased-v2",
    "all-mpnet-base-v2",
    "flax-sentence-embeddings/all_datasets_v3_mpnet-base",
]

# Names of the supported 2-D projection algorithms.
DIMENSIONALITY_REDUCTION_ALGORITHMS = [
    "UMAP",
    "t-SNE",
]

# Granularity options for the text being analysed.
DOCUMENT_TYPES = [
    "Whole document",
    "Sentence",
]

# Shared random seed.
SEED = 0
# Two-letter language codes, kept in alphabetical order.
# NOTE(review): presumably ISO 639-1 codes -- confirm against the consumer.
LANGUAGES = [
    "af", "ar", "az", "be", "bg", "bn", "ca", "cs",
    "da", "de", "el", "en", "es", "et", "fa", "fi",
    "fr", "gu", "he", "hi", "hr", "hu", "hy", "id",
    "is", "it", "ja", "ka", "kk", "km", "kn", "ko",
    "lt", "lv", "mk", "ml", "mn", "mr", "my", "ne",
    "nl", "no", "pl", "pt", "ro", "ru", "uk", "zh",
]

# Names of the available perplexity model variants.
PERPLEXITY_MODELS = [
    "Wikipedia",
    "OSCAR",
]
class ContextLogger:
    """Context manager that logs a message on entry and the elapsed time on exit.

    It is constructed with a single ``text`` argument, which is how
    ``generate_plot`` invokes its ``context_logger`` factory; this class is
    that parameter's default.

    Attributes:
        text: Message logged (at INFO level) when the block is entered.
        start_time: ``time.time()`` timestamp of the most recent entry (or of
            construction, if the manager has not been entered yet).
    """

    def __init__(self, text: str = ""):
        self.text = text
        # Provisional value so the attribute always exists; it is reset on entry.
        self.start_time = time.time()

    def __enter__(self):
        """Start the timer, log ``self.text`` and return this instance."""
        # Time from entry rather than from construction, so a pre-built or
        # reused instance reports only the time spent inside the ``with`` block.
        self.start_time = time.time()
        logger.info(self.text)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Log the elapsed time; returns None, so exceptions are not suppressed."""
        logger.info(f"Took: {time.time() - self.start_time:.4f} seconds")
def generate_plot(
    df: pd.DataFrame,
    text_column: str,
    label_column: str,
    sample: Optional[int],
    dimensionality_reduction_function: Callable,
    model: SentenceTransformer,
    seed: int = 0,
    context_logger: Union[st.spinner, ContextLogger] = ContextLogger,
    hub_dataset: str = "",
) -> Tuple[Figure, Optional[Figure]]:
    """Embed the texts of ``df``, project them to 2-D and draw scatter plot(s).

    Args:
        df: Input data. It is never modified; when a placeholder label column
            has to be added, that is done on a copy.
        text_column: Name of the column holding the text to embed.
        label_column: Name of the column whose rounded values colour the main
            plot. If it is missing, a constant ``0`` column is used instead.
        sample: Maximum number of rows to keep (randomly sampled). A falsy
            value (``None`` or ``0``) keeps every row.
        dimensionality_reduction_function: Callable applied to the embeddings;
            its result is indexed as ``result[:, 0]`` and ``result[:, 1]``.
        model: Model handed to ``embed_text``.
        seed: ``random_state`` used when sampling rows.
        context_logger: Factory called as ``context_logger(text=...)`` that
            returns a context manager wrapping the slow steps.
        hub_dataset: When equal to ``REGISTRY_DATASET``, a second plot coloured
            by the ``"label"`` column is also produced.

    Returns:
        ``(plot, plot_registry)`` where ``plot_registry`` is ``None`` unless
        ``hub_dataset == REGISTRY_DATASET``.

    Raises:
        ValueError: If ``text_column`` is not a column of ``df``.
    """
    if text_column not in df.columns:
        raise ValueError(
            f"The specified column name doesn't exist. Columns available: {df.columns.values}"
        )
    if label_column not in df.columns:
        # Copy first so the placeholder column never leaks into the caller's frame.
        df = df.copy()
        df[label_column] = 0
    df = df.dropna(subset=[text_column, label_column])
    if sample:
        # Never request more rows than are available.
        df = df.sample(min(sample, df.shape[0]), random_state=seed)

    with context_logger(text="Embedding text..."):
        embeddings = embed_text(df[text_column].values.tolist(), model)
    logger.info("Encoding labels")
    encoded_labels = encode_labels(df[label_column])
    with context_logger(text="Reducing dimensionality..."):
        embeddings_2d = dimensionality_reduction_function(embeddings)

    logger.info("Generating figure")
    x_coords = embeddings_2d[:, 0]
    y_coords = embeddings_2d[:, 1]
    texts = df[text_column].values
    hover_data = {
        text_column: texts,
        label_column: encoded_labels.values,
    }
    # Round perplexity values
    values = df[label_column].values.round().astype(int)
    plot = draw_interactive_scatter_plot(
        hover_data,
        x_coords,
        y_coords,
        values,
    )

    # Special case for the registry dataset
    plot_registry = None
    if hub_dataset == REGISTRY_DATASET:
        registry_labels = encode_labels(df["label"])
        registry_hover_data = {
            text_column: texts,
            "label": df["label"].values,
            label_column: df[label_column].values,
        }
        plot_registry = draw_interactive_scatter_plot(
            registry_hover_data,
            x_coords,
            y_coords,
            registry_labels.values,
            palette=Turbo256,
        )
    return plot, plot_registry
|