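# textographe / app.py
# Author: gloignon
# Commit message: "Create app.py"
# Commit: f4d46f6 (verified)
# (Scraped from the Hugging Face file viewer; "raw", "history", "blame" were page links.)
# File size: 1.92 kB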
import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.decomposition import PCA
from sentence_transformers import SentenceTransformer
# Pre-trained sentence-embedding model, loaded once when this module is executed
# and reused by compute_pca for every request.
_MODEL_NAME = 'all-MiniLM-L6-v2'
model = SentenceTransformer(_MODEL_NAME)
# Function to calculate embeddings and PCA
def compute_pca(texts):
    """Embed *texts* and plot their projection onto the first two principal components.

    Args:
        texts: Iterable of texts to plot. Items that are not ``str``, or that
            are empty / whitespace-only, are skipped so blank lines or empty
            table cells do not show up as points.

    Returns:
        A plotly Express scatter figure titled "PCA of Text Embeddings" with
        one point per kept text, labelled with the text itself. When fewer
        than two usable texts remain, no PCA is fitted (it would raise): an
        empty input gives an empty plot and a single text is placed at (0, 0).
    """
    # Keep only real, non-blank strings (table cells may be None/NaN or "").
    kept = [t for t in texts if isinstance(t, str) and t.strip()]

    if len(kept) < 2:
        # PCA(n_components=2) requires at least two samples and raises otherwise.
        # A lone sample becomes the zero vector after centring, so the origin is
        # exactly where PCA would have put it.
        pca_result = np.zeros((len(kept), 2))
    else:
        embeddings = model.encode(kept)
        pca_result = PCA(n_components=2).fit_transform(embeddings)

    # Create DataFrame for visualization
    df = pd.DataFrame({
        'Text': kept,
        'PC1': pca_result[:, 0],
        'PC2': pca_result[:, 1],
    })
    # Plot the PCA result
    fig = px.scatter(df, x='PC1', y='PC2', text='Text', title='PCA of Text Embeddings')
    return fig
# Define Gradio app layout and interactions
def column_texts(table):
    """Return the non-blank string cells from the first column of a table value.

    ``gr.Dataframe`` passes its value to event handlers as a pandas DataFrame
    by default (``type="pandas"``). A list of rows (``type="array"``) or
    ``None`` is accepted too. Iterating a DataFrame directly yields its column
    labels rather than its rows, which is why the DataFrame case is read
    through ``iloc``.

    Args:
        table: A pandas DataFrame, a list of row sequences, or None.

    Returns:
        The first-column values that are non-empty strings, in row order.
    """
    if table is None:
        return []
    if isinstance(table, pd.DataFrame):
        cells = [] if table.empty else table.iloc[:, 0].tolist()
    else:
        cells = [row[0] for row in table if len(row) > 0]
    # Drop empty cells and non-string placeholders (None/NaN) before embedding.
    return [c for c in cells if isinstance(c, str) and c.strip()]


def text_editor_app():
    """Build the Gradio Blocks UI that embeds texts and plots their PCA.

    Editing the textbox copies its non-blank lines into an editable
    one-column table. The button sends the table's texts to ``compute_pca``
    and displays the resulting scatter plot.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks() as demo:
        # Text box to input texts
        text_input = gr.Textbox(lines=10, placeholder="Enter or paste your texts here, one per line...", label="Text Inputs")
        # Editable list of texts, one per row
        texts = gr.Dataframe(headers=["Texts"], label="Text List", interactive=True)
        # Button to process texts
        submit_button = gr.Button("Compute Embeddings and PCA")
        # Output plot
        output_plot = gr.Plot(label="PCA Visualization")

        def process_texts(raw):
            """Turn the textbox contents into a one-column table of non-blank lines."""
            # splitlines() also handles "\r\n" from pasted Windows text.
            lines = [line.strip() for line in (raw or "").splitlines()]
            # gr.DataFrame.update was removed in Gradio 4. Returning the new value
            # directly works on Gradio 3 and 4 and keeps the "Texts" header.
            return pd.DataFrame({"Texts": [line for line in lines if line]})

        submit_button.click(fn=lambda table: compute_pca(column_texts(table)), inputs=texts, outputs=output_plot)
        text_input.change(fn=process_texts, inputs=text_input, outputs=texts)
    return demo
# Build the interface, keep a handle to it, then start the Gradio server.
demo = text_editor_app()
demo.launch()