davanstrien's picture
davanstrien HF staff
app
b99b870
raw
history blame
1.51 kB
import gradio as gr
from bertopic import BERTopic
from datasets import load_dataset
from functools import lru_cache
@lru_cache(maxsize=1)
def prep_dataset():
    """Return the text of every English assistant message in OASST1.

    Loads the ``train`` split of ``OpenAssistant/oasst1`` and keeps only the
    rows whose ``role`` is ``"assistant"`` and whose ``lang`` is ``"en"``.

    The result is memoised, so the dataset is loaded and filtered at most
    once per process. Every caller receives the same object and should
    treat it as read-only.

    Returns:
        The ``text`` column of the filtered dataset.
    """
    dataset = load_dataset("OpenAssistant/oasst1", split="train")
    # Apply both predicates in a single pass instead of filtering twice.
    english_assistant = dataset.filter(
        lambda row: row["role"] == "assistant" and row["lang"] == "en"
    )
    return english_assistant["text"]
# Shared topic model used by every callback below. Loaded once at import
# time from "davanstrien/chat_topics" (presumably a Hugging Face Hub repo
# id -- confirm against the BERTopic version in use).
topic_model = BERTopic.load("davanstrien/chat_topics")
# Topic overview figure, built once up front and shown as-is in the
# "Topic distribution" tab.
fig = topic_model.visualize_topics()
def plot_docs():
    """Build the document-level visualisation for the prepared corpus.

    Returns:
        The result of ``topic_model.visualize_documents`` applied to the
        texts returned by ``prep_dataset()``.
    """
    return topic_model.visualize_documents(prep_dataset())
def search_topic(text):
    """Return the topic-info rows for the topics most similar to *text*.

    Args:
        text: Free-text query, as typed into the search textbox.

    Returns:
        A DataFrame containing the ``topic_model.get_topic_info()`` rows of
        up to five matching topics, ordered from most to least similar.
        When *text* is ``None``, empty or only whitespace, an empty
        DataFrame with the same columns is returned instead.
    """
    topic_info = topic_model.get_topic_info()

    # The textbox "change" event also fires when the box is cleared; a blank
    # query carries no meaning, so show an empty table rather than whatever
    # topics happen to be closest to an empty string.
    if text is None or not text.strip():
        return topic_info.iloc[0:0]

    similar_topics, _ = topic_model.find_topics(text, top_n=5)
    matches = topic_info[topic_info["Topic"].isin(similar_topics)]

    # ``isin`` keeps the table's own row order; re-sort so the rows follow
    # the order in which ``find_topics`` ranked the topics.
    rank = {topic: position for position, topic in enumerate(similar_topics)}
    return matches.sort_values("Topic", key=lambda column: column.map(rank))
def plot_topic_words(num_topics=9, n_words=5):
    """Render the per-topic word bar chart.

    Args:
        num_topics: Forwarded to ``visualize_barchart`` as ``top_n_topics``.
        n_words: Forwarded to ``visualize_barchart`` as ``n_words``.

    Returns:
        The figure produced by ``topic_model.visualize_barchart``.
    """
    barchart_options = {"top_n_topics": num_topics, "n_words": n_words}
    return topic_model.visualize_barchart(**barchart_options)
# Build the tabbed interface: a bar chart of topic words driven by a slider,
# a free-text topic search, and the precomputed topic overview figure.
with gr.Blocks() as demo:
    with gr.Tab("Topic words"):
        topic_number = gr.Slider(
            label="Number of topics",
            minimum=3,
            maximum=20,
            step=1,
            value=9,
        )
        # Initial chart uses plot_topic_words' defaults; the slider then
        # re-renders it with the selected number of topics.
        plot = gr.Plot(value=plot_topic_words())
        topic_number.change(
            fn=plot_topic_words,
            inputs=[topic_number],
            outputs=plot,
        )

    with gr.Tab("Topic search"):
        text = gr.Textbox(label="Search text", lines=1)
        df = gr.DataFrame()
        # Re-run the search on every edit of the textbox.
        text.change(fn=search_topic, inputs=[text], outputs=df)

    with gr.Tab("Topic distribution"):
        gr.Plot(value=fig)

    # Disabled tab, kept for reference:
    # with gr.Tab("Doc visualization"):
    #     gr.Plot(plot_docs())

demo.launch(debug=True)