# Roman Castagné
# Initial commit
# 54c73d3
import json
from Levenshtein import distance
import streamlit as st
import numpy as np
import plotly.express as px
from sklearn.decomposition import PCA
def load_data():
    """Load the precomputed embeddings and the matching word list.

    Reads ``data/simplesegmentT5_embeddings.npy`` with ``np.load`` and
    ``data/words.json`` with ``json.load``. Both paths are relative to the
    current working directory.

    Returns:
        A ``(embeddings, words)`` tuple: the loaded array followed by the
        decoded JSON content.
    """
    embeddings = np.load("data/simplesegmentT5_embeddings.npy")
    # The context manager closes the file handle deterministically instead of
    # leaving it to the garbage collector. JSON is UTF-8 by specification, so
    # the encoding is pinned rather than taken from the platform default.
    with open("data/words.json", "r", encoding="utf-8") as words_file:
        words = json.load(words_file)
    return embeddings, words
def project_embeddings(embeddings):
    """Reduce ``embeddings`` to three principal components.

    A fresh scikit-learn ``PCA`` with ``n_components=3`` is fitted on
    ``embeddings`` on every call, and the transformed data is returned.
    """
    reducer = PCA(n_components=3)
    return reducer.fit_transform(embeddings)
def filter_words(words, remove_capitalized, length):
    """Return the positions of the words that pass both filters.

    Args:
        words: Sequence of strings to examine.
        remove_capitalized: When truthy, drop every word that differs from
            its lower-cased form (i.e. contains an upper-case character).
        length: Indexable pair ``(shortest, longest)``; a word is kept only
            when its length lies within these bounds, both inclusive.

    Returns:
        List of indices into ``words``, in ascending order.
    """
    return [
        position
        for position, word in enumerate(words)
        # Capitalization is checked first, then the inclusive length window.
        if not (remove_capitalized and word != word.lower())
        and length[0] <= len(word) <= length[1]
    ]
def color_length(words):
    """Return the length of every word, in input order, as color values."""
    return list(map(len, words))
def color_first_letter(words):
    """Map each word to a color value in ``[0, 1]`` from its first character.

    The word is lower-cased, the code point of its first character is offset
    by 97 (``ord("a")``) and divided by 26, and the result is clamped to the
    ``[0, 1]`` interval. An empty word raises ``IndexError`` because its
    first character is indexed directly.
    """
    values = []
    for word in words:
        scaled = (ord(word.lower()[0]) - 97) / 26
        # Characters before "a" clamp to 0; anything 26 or more past "a"
        # clamps to 1.
        values.append(min(1, max(0, scaled)))
    return values
def color_levenshtein(words):
    """Color each word by its Levenshtein distance to a reference word.

    The reference is the word at index 4 of ``words``. The UI labels this
    option as a "random word", but the choice is a fixed position, so the
    coloring is stable across calls with the same list.

    Args:
        words: Sequence of strings.

    Returns:
        One distance per word, in input order. An empty input yields an
        empty list. When fewer than five words are given, the last word is
        used as the reference instead of raising ``IndexError``.
    """
    if not words:
        return []
    # Clamp the fixed reference position so short (heavily filtered) lists
    # still get a valid reference. Lists of five or more words are unaffected.
    reference = words[min(4, len(words) - 1)]
    return [distance(word, reference) for word in words]
def plot_scatter(words, embeddings, remove_capitalized, length, color_select):
    """Build the 3D scatter figure for the words that pass the filters.

    Args:
        words: Word list, index-aligned with ``embeddings``.
        embeddings: Array indexed with the list of kept positions.
        remove_capitalized: Forwarded to ``filter_words``.
        length: Forwarded to ``filter_words``.
        color_select: ``"Word length"`` colors points with ``color_length``;
            any other value colors them with ``color_levenshtein``.

    Returns:
        The Plotly figure produced by ``px.scatter_3d`` after the marker
        styling has been applied.
    """
    kept = filter_words(words, remove_capitalized, length)
    # The projection is fitted on the kept rows only, not on the full matrix.
    coords = project_embeddings(embeddings[kept])
    kept_words = [words[i] for i in kept]

    pick_color = color_length if color_select == "Word length" else color_levenshtein
    point_colors = pick_color(kept_words)

    figure = px.scatter_3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        color=point_colors,
        hover_name=kept_words,
        color_continuous_scale=px.colors.sequential.Viridis,
        title="SimpleSegmentT5 Embeddings",
        width=800,
        height=600,
    )

    marker_style = {"size": 6, "line": {"width": 2}}
    figure.update_traces(marker=marker_style, selector={"mode": "markers"})
    return figure
def main():
    """Render the Streamlit page: sidebar settings plus the 3D scatter plot.

    Loads the data, reads the three sidebar controls (capitalization
    checkbox, word-length range slider, color radio) and displays the figure
    that ``plot_scatter`` builds from them.
    """
    embeddings, words = load_data()

    # An unused PCA over the full embedding matrix and an unused figure were
    # built here before. They were never displayed, so they only added cost
    # to every script run; ``plot_scatter`` builds the figure that is shown.

    st.sidebar.title("Settings")
    remove_checkbox = st.sidebar.checkbox(
        "Remove capitalized words",
        value=True,
        # NOTE(review): the key says "include" while the label says "Remove".
        # It is kept as-is because it is the widget's session-state key.
        key="include_capitalized",
    )
    # The tuple default makes this a range slider over word lengths 3..9.
    length_slider = st.sidebar.slider("Word length", 3, 9, (3, 9))
    color_select = st.sidebar.radio(
        "Color by",
        ["Word length", "Levenshtein distance to random word"],
    )

    st.plotly_chart(
        plot_scatter(words, embeddings, remove_checkbox, length_slider, color_select)
    )


if __name__ == "__main__":
    main()