from collections import Counter
import streamlit as st
import json
from itertools import islice
from typing import Generator
from plotly import express as px
from safetensors import safe_open
from semantic_search import predict
from sentence_transformers import SentenceTransformer
import os
# Hugging Face access token taken from the environment; ``None`` when the
# variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")
def chunks(data: dict, size: int = 13) -> Generator[dict, None, None]:
    """Yield consecutive sub-dicts of ``data`` holding at most ``size`` items.

    Items are taken in the dict's iteration order; every chunk except
    possibly the last contains exactly ``size`` entries. An empty ``data``
    yields nothing.

    Args:
        data: Mapping to split up.
        size: Maximum number of items per chunk; must be positive.

    Yields:
        A new dict for each chunk (the input is not modified).

    Raises:
        ValueError: If ``size`` is zero or negative. Because this is a
            generator, the error surfaces when iteration starts.
    """
    # A negative size would make the range below empty and silently drop all
    # data, so reject it the same way a zero step is rejected.
    if size <= 0:
        raise ValueError(f"size must be a positive integer, got {size!r}")

    # One shared key iterator: each islice call continues where the previous
    # chunk stopped.
    keys = iter(data)
    for _ in range(0, len(data), size):
        yield {key: data[key] for key in islice(keys, size)}
def get_tree_map_data(
    data: dict,
    countings_parents: dict,
    countings_labels: dict,
    root: str = " ",
) -> tuple:
    """Flatten a ``group -> labels`` mapping into parallel treemap lists.

    The three returned lists have the same length and are aligned by index:
    entry ``i`` describes one node by its display name, its parent and its
    count (as a string). The first entry is always a placeholder node with
    name ``""``, parent ``root`` and value ``"0"``. After it, every group is
    emitted with ``root`` as parent, followed by that group's labels with
    the group as parent.

    Args:
        data: Mapping of group name to an iterable of label strings.
        countings_parents: Count per group name; missing groups count as 0.
        countings_labels: Count per label, keyed by the label exactly as it
            appears in ``data``; missing labels count as 0.
        root: Name used as the parent of every group.

    Returns:
        A ``(parents, names, values)`` tuple of lists. Note the order:
        parents come first.
    """
    names: list = [""]
    parents: list = [root]
    values: list = ["0"]

    for group, labels in data.items():
        names.append(group)
        parents.append(root)
        values.append(str(countings_parents.get(group, 0)))

        for label in labels:
            # Insert a line break before the first hyphen so long compound
            # labels wrap. Only the first hyphen is used, so labels with
            # several hyphens keep all of their text.
            # NOTE(review): the separator literal was split across two lines
            # in the source; "<br>" is assumed because that is the tag Plotly
            # renders as a line break -- confirm.
            head, hyphen, tail = label.partition("-")
            display_name = f"{head}<br>-{tail}" if hyphen else label

            names.append(display_name)
            parents.append(group)
            # Look the count up with the untouched label; the display name
            # contains markup and would never match a key in the counts.
            values.append(str(countings_labels.get(label, 0)))

    return parents, names, values
def load_json(path: str) -> dict:
    """Read the JSON document stored at ``path`` and return the parsed value.

    The return annotation says ``dict``, but the result is whatever the
    document's top-level value is; the callers in this file iterate the
    result as a sequence of records.

    Args:
        path: Location of the JSON file.

    Returns:
        The value produced by ``json.load`` for the file's contents.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the contents are not valid JSON.
    """
    # Pin the encoding: without it ``open`` uses the platform's locale
    # default, which can mis-decode non-ASCII text such as German umlauts.
    with open(path, "r", encoding="utf-8") as fp:
        return json.load(fp)
# Load Data
data = load_json("data.json")
taxonomy = load_json("taxonomy_processed_v3.json")

# One "<group> - <label>" string per taxonomy entry, in file order.
taxonomy_labels = [el["group"] + " - " + el["label"] for el in taxonomy]

# Number of records in ``data`` per distinct THEMA / BEZEICHNUNG value.
theme_counts = dict(Counter(el["THEMA"] for el in data))
labels_counts = dict(Counter(el["BEZEICHNUNG"] for el in data))

# Group name -> list of labels, with groups in order of first appearance.
# Entries of the "Sonstiges" group are stored as "Sonstiges " (trailing
# space) instead of their own label -- presumably so the leaf name differs
# from the group name in the treemap; confirm before changing.
taxonomy_group_label_mapper: dict = {}
for el in taxonomy:
    group = el["group"]
    leaf = el["label"] if group != "Sonstiges" else "Sonstiges "
    taxonomy_group_label_mapper.setdefault(group, []).append(leaf)
# Turn the taxonomy mapping into the index-aligned parent/name lists the
# treemap is built from. ``values`` is returned as well but is not passed
# to the figure below.
parents, name, values = get_tree_map_data(
    taxonomy_group_label_mapper,
    theme_counts,
    labels_counts,
    "Musterdatenkatalog",
)

# Treemap of the taxonomy hierarchy, built from names and parents only.
fig = px.treemap(names=name, parents=parents)

# Fixed 1000x1000 canvas with a slightly larger top margin.
fig.update_layout(
    template="plotly",
    width=1000,
    height=1000,
    margin={"t": 50, "l": 25, "r": 25, "b": 25},
)
# Read every tensor stored in the embeddings file into memory (on the CPU),
# keyed by the name it has in the file.
with safe_open("corpus_embeddings.pt", framework="pt", device="cpu") as f:
    tensors = {tensor_name: f.get_tensor(tensor_name) for tensor_name in f.keys()}

# Sentence-embedding model used for predictions; runs on the CPU and is
# fetched with the token from the environment.
# NOTE(review): this sits at module level, so it executes whenever the
# script body runs -- consider caching if load time becomes a problem.
model = SentenceTransformer(
    model_name_or_path="and-effect/musterdatenkatalog_clf",
    use_auth_token=HF_TOKEN,
    device="cpu",
)
st.set_page_config(layout="wide")

st.title("Musterdatenkatalog")

# Three headline figures shown side by side: number of records, number of
# distinct themes and number of distinct labels.
col1, col2, col3 = st.columns(3)
headline_metrics = (
    ("Kommunale Datensätze", len(data)),
    ("Themen", len(theme_counts)),
    ("Bezeichnungen", len(labels_counts)),
)
for column, (caption, count) in zip((col1, col2, col3), headline_metrics):
    column.metric(caption, count)

# Taxonomy section: the treemap figure built earlier in the script.
st.title("Taxonomy")
st.plotly_chart(fig)
st.title("Predict a Dataset")

# NOTE(review): the markdown body is a single newline, so even with raw HTML
# enabled this call injects no markup. Confirm whether custom HTML/CSS was
# meant to be placed here.
markdown_body = "\n"
st.markdown(markdown_body, unsafe_allow_html=True)

# Layout for the widgets that follow: two columns, the left one wider.
col1, col2 = st.columns([1.2, 1])  # relative widths: left 1.2, right 1

# The example buttons are handled before the left column so that a clicked
# example is already in the session state when the text input is created.
with col2:
    st.subheader("Example Datasets")
    examples = [
        "Spielplätze",
        "Berliner Weihnachtsmärkte 2022",
        "Hochschulwechslerquoten zum Masterstudium nach Bundesländern",
        "Umringe der Bebauungspläne von Etgert",
    ]
    for example in examples:
        if st.button(example):
            # Previously guarded by a check for a "key" entry that nothing in
            # the visible code ever sets, so the guard was always true. The
            # clicked example is therefore stored unconditionally.
            st.session_state["query"] = example

# NOTE(review): the slider, the prediction and the "Predict" button are
# assumed to belong to the left column together with the text input --
# confirm the intended nesting.
with col1:
    # Exactly one of the three cases applies, so exactly one input is drawn.
    if "query" not in st.session_state:
        query = st.text_input(
            "Enter dataset name",
        )
    elif st.session_state.query in examples:
        # Pre-fill the input with the example that was clicked.
        query = st.text_input("Enter dataset name", value=st.session_state.query)
    else:
        # A stored query that is not one of the examples is discarded.
        del st.session_state["query"]
        query = st.text_input("Enter dataset name")

    top_k = st.select_slider("Top Results", options=[1, 2, 3, 4, 5], value=1)

    if st.button("Predict"):
        # Run the model only when a prediction is actually requested, rather
        # than every time the script body executes.
        predictions = predict(
            query=query,
            corpus_embeddings=tensors["corpus_embeddings"],
            corpus_labels=taxonomy_labels,
            top_k=top_k,
            model=model,
        )
        for prediction in predictions:
            st.write(prediction)