from collections import Counter
import streamlit as st
import json
from itertools import islice
from typing import Generator
from plotly import express as px
from safetensors import safe_open
from semantic_search import predict
from sentence_transformers import SentenceTransformer
import os

# Hugging Face auth token used to load the classifier model below.
HF_TOKEN = os.environ.get("HF_TOKEN")


def chunks(data: dict, size: int = 13) -> Generator:
    """Yield successive sub-dicts of *data* containing at most *size* keys each.

    Keys keep their original iteration order; the final chunk may be smaller.
    """
    it = iter(data)
    for _ in range(0, len(data), size):
        yield {key: data[key] for key in islice(it, size)}


def get_tree_map_data(
    data: dict,
    countings_parents: dict,
    countings_labels: dict,
    root: str = " ",
) -> tuple:
    """Flatten a ``{group: [labels]}`` taxonomy into parallel treemap arrays.

    Args:
        data: mapping of taxonomy group name -> list of label strings.
        countings_parents: occurrence counts keyed by group name.
        countings_labels: occurrence counts keyed by label name.
        root: name of the synthetic root node every group hangs off.

    Returns:
        ``(parents, names, values)`` lists suitable for a plotly treemap.
        Counts are returned as strings; missing entries default to ``"0"``.
    """
    names: list = [""]
    parents: list = [root]
    values: list = ["0"]
    for group, labels in data.items():
        names.append(group)
        parents.append(root)
        values.append(str(countings_parents.get(group, 0)))
        for label in labels:
            if "-" in label:
                # Insert a line break before the hyphenated part so long
                # labels wrap inside their treemap cell.
                # FIX: split("-", 1) keeps the full remainder — the original
                # split("-") dropped everything after a second hyphen.
                head, tail = label.split("-", 1)
                label = head + "\n-" + tail
            names.append(label)
            parents.append(group)
            # NOTE(review): lookup uses the (possibly rewritten) label, so
            # hyphenated labels may never match a count key — TODO confirm.
            values.append(str(countings_labels.get(label, 0)))
    return parents, names, values


def load_json(path: str) -> dict:
    """Read and parse the JSON file at *path*."""
    with open(path, "r") as fp:
        return json.load(fp)


# --- Load data -------------------------------------------------------------
data = load_json("data.json")
taxonomy = load_json("taxonomy_processed_v3.json")
taxonomy_labels = [el["group"] + " - " + el["label"] for el in taxonomy]

theme_counts = dict(Counter([el["THEMA"] for el in data]))
labels_counts = dict(Counter([el["BEZEICHNUNG"] for el in data]))

# Map each taxonomy group to its labels. "Sonstiges" gets a trailing space so
# the label node does not collide with the group node of the same name.
taxonomy_group_label_mapper: dict = {el["group"]: [] for el in taxonomy}
for el in taxonomy:
    if el["group"] != "Sonstiges":
        taxonomy_group_label_mapper[el["group"]].append(el["label"])
    else:
        taxonomy_group_label_mapper[el["group"]].append("Sonstiges ")

parents, names, values = get_tree_map_data(
    data=taxonomy_group_label_mapper,
    countings_parents=theme_counts,
    countings_labels=labels_counts,
    root="Musterdatenkatalog",
)

# NOTE(review): `values` is computed but deliberately not passed here, so all
# treemap cells render equally sized; pass values=values to size by count.
fig = px.treemap(
    names=names,
    parents=parents,
)
fig.update_layout(
    margin=dict(t=50, l=25, r=25, b=25),
    height=1000,
    width=1000,
    template="plotly",
)

# Pre-computed corpus embeddings stored in safetensors format.
tensors = {}
with safe_open("corpus_embeddings.pt", framework="pt", device="cpu") as f:
    for k in f.keys():
        tensors[k] = f.get_tensor(k)

model = SentenceTransformer(
    model_name_or_path="and-effect/musterdatenkatalog_clf",
    device="cpu",
    use_auth_token=HF_TOKEN,
)

# --- Page header -----------------------------------------------------------
st.set_page_config(layout="wide")
st.title("Musterdatenkatalog")

col1, col2, col3 = st.columns(3)
col1.metric("Kommunale Datensätze", len(data))
col2.metric("Themen", len(theme_counts))
col3.metric("Bezeichnungen", len(labels_counts))

st.title("Taxonomy")
st.plotly_chart(fig)

st.title("Predict a Dataset")
# Empty markdown placeholder (kept — presumably reserved for CSS/HTML
# injection via unsafe_allow_html; verify before removing).
st.markdown(
    """ """,
    unsafe_allow_html=True,
)

col1, col2 = st.columns([1.2, 1])

with col2:
    st.subheader("Example Datasets")
    examples = [
        "Spielplätze",
        "Berliner Weihnachtsmärkte 2022",
        "Hochschulwechslerquoten zum Masterstudium nach Bundesländern",
        "Umringe der Bebauungspläne von Etgert",
    ]
    for example in examples:
        if st.button(example):
            # FIX: the original guard `if "key" not in st.session_state`
            # checked a key that is never written anywhere, so it was always
            # true — clicking an example always selects it. Assign directly.
            st.session_state["query"] = example

with col1:
    # Three mutually exclusive input states:
    #   1. no stored query        -> plain text input
    #   2. stored query is one of the examples -> prefill the input with it
    #   3. stored query is stale  -> drop it and show a plain input
    if "query" not in st.session_state:
        query = st.text_input(
            "Enter dataset name",
        )
    if "query" in st.session_state and st.session_state.query in examples:
        query = st.text_input("Enter dataset name", value=st.session_state.query)
    if "query" in st.session_state and st.session_state.query not in examples:
        del st.session_state["query"]
        query = st.text_input("Enter dataset name")

    top_k = st.select_slider("Top Results", options=[1, 2, 3, 4, 5], value=1)

    if st.button("Predict"):
        # FIX: run the semantic search only when the button is pressed — the
        # original called predict() on every Streamlit rerun and discarded
        # the result unless the button happened to be clicked.
        predictions = predict(
            query=query,
            corpus_embeddings=tensors["corpus_embeddings"],
            corpus_labels=taxonomy_labels,
            top_k=top_k,
            model=model,
        )
        for prediction in predictions:
            st.write(prediction)