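# Streamlit demo: build a Bunka topic map from a Hugging Face dataset,
# optionally clean topic labels with an LLM, and export the cleaned data.
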
import streamlit as st
import os
from datasets import load_dataset
from bunkatopics import Bunka
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub

# Streamlit app
st.title("Bunka Map 🗺️")

# Input parameters
dataset_id = st.text_input("Dataset ID", "bunkalab/medium-sample-technology")
language = st.text_input("Language", "english")
text_field = st.text_input("Text Field", "title")
embedder_model = st.text_input("Embedder Model", "sentence-transformers/distiluse-base-multilingual-cased-v2")
sample_size = st.number_input("Sample Size", min_value=100, max_value=10000, value=1000)
n_clusters = st.number_input("Number of Clusters", min_value=5, max_value=50, value=15)
llm_model = st.text_input("LLM Model", "mistralai/Mistral-7B-Instruct-v0.1")

# Hugging Face API token input
hf_token = st.text_input("Hugging Face API Token", type="password")

if st.button("Generate Bunka Map"):
    # Load dataset and sample
    @st.cache_data
    def load_data(dataset_id, text_field, sample_size):
        dataset = load_dataset(dataset_id, streaming=True)
        docs_sample = []
        for i, example in enumerate(dataset["train"]):
            if i >= sample_size:
                break
            docs_sample.append(example[text_field])
        return docs_sample

    docs_sample = load_data(dataset_id, text_field, sample_size)

    # Initialize embedding model and Bunka
    embedding_model = HuggingFaceEmbeddings(model_name=embedder_model)
    bunka = Bunka(embedding_model=embedding_model, language=language)

    # Fit Bunka to the text data
    bunka.fit(docs_sample)

    # Generate topics
    df_topics = bunka.get_topics(n_clusters=n_clusters, name_length=5, min_count_terms=2)

    # Keep the fitted model across reruns; the nested st.button calls further down
    # would otherwise rerun the script with this block skipped and never execute.
    st.session_state["bunka"] = bunka
    st.session_state["labels_cleaned"] = False

# Interactive sections run on every rerun once a map has been generated
if "bunka" in st.session_state:
    bunka = st.session_state["bunka"]

    # Visualize topics
    st.plotly_chart(bunka.visualize_topics(width=800, height=800, colorscale='Portland', density=True, label_size_ratio=60, convex_hull=True))

    # Clean labels using an LLM (run once per fitted model to avoid repeated API calls)
    if hf_token:
        if not st.session_state.get("labels_cleaned", False):
            os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
            llm = HuggingFaceHub(repo_id=llm_model, huggingfacehub_api_token=hf_token)
            bunka.get_clean_topic_name(llm=llm, language=language)
            st.session_state["labels_cleaned"] = True
        st.plotly_chart(bunka.visualize_topics(width=800, height=800, colorscale='Portland', density=True, label_size_ratio=60, convex_hull=True))
    else:
        st.warning("Please provide a Hugging Face API token to clean labels with an LLM.")

    # Manual topic cleaning
    st.subheader("Manually Clean Topics")
    cleaned_topics = {}
    for topic, keywords in bunka.topics_.items():
        cleaned_topic = st.text_input(f"Topic {topic}", ", ".join(keywords))
        cleaned_topics[topic] = cleaned_topic.split(", ")
    
    if st.button("Update Topics"):
        bunka.topics_ = cleaned_topics
        st.plotly_chart(bunka.visualize_topics(width=800, height=800, colorscale='Portland', density=True, label_size_ratio=60, convex_hull=True))

    # Remove unwanted topics
    st.subheader("Remove Unwanted Topics")
    topics_to_remove = st.multiselect("Select topics to remove", list(bunka.topics_.keys()))
    if st.button("Remove Topics"):
        bunka.clean_data_by_topics(topics_to_remove)
        st.plotly_chart(bunka.visualize_topics(width=800, height=800, colorscale='Portland', density=True, label_size_ratio=60, convex_hull=True))

    # Save dataset (df_cleaned_ is only available after unwanted topics have been removed)
    if st.button("Save Cleaned Dataset"):
        if hasattr(bunka, "df_cleaned_"):
            name = dataset_id.replace('/', '_') + '_cleaned.csv'
            bunka.df_cleaned_.to_csv(name)
            st.success(f"Dataset saved as {name}")
        else:
            st.warning("Remove unwanted topics first to create a cleaned dataset.")