# Hugging Face Space status at capture time: Runtime error
import gradio as gr | |
from bertopic import BERTopic | |
from datasets import load_dataset | |
from functools import lru_cache | |
@lru_cache(maxsize=1)
def prep_dataset():
    """Return the text of every English assistant message in OASST1's train split.

    The result is memoized with ``lru_cache`` so the dataset is loaded and
    filtered at most once per process, no matter how often this is called.
    Every caller receives the same cached object, so it must not be mutated.

    Returns:
        The ``"text"`` column of the rows whose ``role`` is ``"assistant"``
        and whose ``lang`` is ``"en"``.
    """
    dataset = load_dataset("OpenAssistant/oasst1", split="train")
    # A single pass applying both predicates replaces two chained .filter()
    # calls; the kept rows and their order are the same.
    english_replies = dataset.filter(
        lambda row: row["role"] == "assistant" and row["lang"] == "en"
    )
    return english_replies["text"]
# Hugging Face Hub repository holding the pretrained BERTopic model.
_TOPIC_MODEL_REPO = "davanstrien/chat_topics"
topic_model = BERTopic.load(_TOPIC_MODEL_REPO)

# Topic overview figure, computed once when the module is imported so the
# "Topic distribution" tab can show it without rebuilding it per request.
fig = topic_model.visualize_topics()
def plot_docs():
    """Build the BERTopic document visualization for the English assistant replies.

    Returns:
        The figure produced by ``topic_model.visualize_documents`` for the
        texts returned by ``prep_dataset()``.
    """
    return topic_model.visualize_documents(prep_dataset())
def search_topic(text, top_n=5):
    """Return the topic-info rows for the topics most similar to ``text``.

    Args:
        text: Free-text query, as typed into the search box.
        top_n: How many similar topics to ask ``find_topics`` for.

    Returns:
        A slice of ``topic_model.get_topic_info()`` that contains only the
        matched topics, ordered from most to least similar (the order
        ``find_topics`` returns them in). If the query is empty or
        whitespace-only, returns an empty frame with the same columns.
    """
    topic_info = topic_model.get_topic_info()
    # The textbox's change event also fires when the box is cleared; there
    # is nothing meaningful to search for, so show an empty table.
    if not (text and text.strip()):
        return topic_info.iloc[0:0]
    similar_topics, _ = topic_model.find_topics(text, top_n=top_n)
    # Filtering with isin() keeps the frame's own row order (by topic id).
    # Re-sort by each topic's position in the similarity ranking instead.
    rank = {topic: position for position, topic in enumerate(similar_topics)}
    matches = topic_info[topic_info["Topic"].isin(similar_topics)]
    return matches.sort_values("Topic", key=lambda column: column.map(rank))
def plot_topic_words(num_topics=9, n_words=5):
    """Return BERTopic's bar chart of the top words for the leading topics.

    Args:
        num_topics: Number of topics to chart. The default of 9 matches the
            slider's initial value in the UI.
        n_words: Number of words shown per topic.

    Returns:
        The figure produced by ``topic_model.visualize_barchart``.
    """
    # NOTE(review): Gradio sliders may hand the value over as a float
    # (e.g. 9.0), while both arguments are counts; coerce them to int.
    # Confirm against the installed gradio / BERTopic versions.
    return topic_model.visualize_barchart(
        top_n_topics=int(num_topics), n_words=int(n_words)
    )
with gr.Blocks() as demo:
    # Tab 1: bar charts of each topic's top words, redrawn when the slider moves.
    with gr.Tab("Topic words"):
        topic_number = gr.Slider(
            minimum=3, maximum=20, value=9, step=1, label="Number of topics"
        )
        # The initial chart uses plot_topic_words' default (9 topics), which
        # matches the slider's starting value above.
        plot = gr.Plot(plot_topic_words())
        topic_number.change(plot_topic_words, [topic_number], plot)
    # Tab 2: free-text search showing the most similar topics as a table.
    with gr.Tab("Topic search"):
        text = gr.Textbox(lines=1, label="Search text")
        df = gr.DataFrame()
        # A change event, so the search reruns as the text is edited rather
        # than only on submit.
        text.change(search_topic, [text], df)
    # Tab 3: the topic figure computed once at module import (``fig``).
    with gr.Tab("Topic distribution"):
        gr.Plot(fig)
    # NOTE(review): the document-level plot is disabled. This is presumably
    # because plot_docs() loads the whole dataset and visualizes every
    # document while the UI is being built; confirm before re-enabling.
    # with gr.Tab("Doc visualization"):
    #     gr.Plot(plot_docs())

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not launch a blocking server as a side effect.
if __name__ == "__main__":
    demo.launch(debug=True)