Spaces:
Sleeping
Sleeping
from collections import Counter | |
import streamlit as st | |
import json | |
from itertools import islice | |
from typing import Generator | |
from plotly import express as px | |
from safetensors import safe_open | |
from semantic_search import predict | |
from sentence_transformers import SentenceTransformer | |
import os | |
# Hugging Face access token taken from the environment (None when unset);
# handed to SentenceTransformer as its auth token.
HF_TOKEN = os.getenv("HF_TOKEN")
def chunks(data: dict, size=13) -> Generator:
    """Split *data* into consecutive sub-dicts of at most *size* entries.

    Entries keep the dict's iteration order; only the last chunk can be
    shorter than *size*. Evaluation is lazy: nothing is read until the
    generator is advanced.
    """
    key_stream = iter(data)
    chunk_count = len(range(0, len(data), size))
    for _ in range(chunk_count):
        yield dict((key, data[key]) for key in islice(key_stream, size))
def _format_label(label: str) -> str:
    """Break a ``head-tail`` label onto two treemap lines at its first hyphen."""
    head, sep, tail = label.partition("-")
    if not sep:
        return label
    # Keep everything after the first hyphen (a plain split would drop any
    # text following a second hyphen).
    return head + "<br> -" + tail


def get_tree_map_data(
    data: dict,
    countings_parents: dict,
    countings_labels: dict,
    root: str = " ",
) -> tuple:
    """Flatten a ``group -> labels`` mapping into parallel treemap lists.

    Args:
        data: Mapping from group name to an iterable of label strings.
        countings_parents: Occurrence count per group name.
        countings_labels: Occurrence count per raw (unformatted) label.
        root: Parent assigned to the leading placeholder entry and to
            every group.

    Returns:
        ``(parents, names, values)``, three equally long lists. The first
        entry is an empty-named placeholder under *root*. Each group follows
        with *root* as its parent, then the group's labels with the group as
        their parent. Values are counts rendered as strings (``"0"`` when a
        name has no count).
    """
    names: list = [""]
    parents: list = [root]
    values: list = ["0"]
    for group, labels in data.items():
        names.append(group)
        parents.append(root)
        values.append(str(countings_parents.get(group, 0)))
        for label in labels:
            # Look the count up under the raw label: the display form below
            # contains "<br>" and would never match a key in countings_labels.
            values.append(str(countings_labels.get(label, 0)))
            names.append(_format_label(label))
            parents.append(group)
    return parents, names, values
def load_json(path: str) -> "dict | list":
    """Read and parse the JSON document stored at *path*.

    The file is decoded as UTF-8 explicitly. JSON exchanged between systems
    must be UTF-8 (RFC 8259), and the platform default encoding (e.g. cp1252
    on Windows) would garble or reject the German umlauts in the data.

    Returns:
        The decoded top-level JSON value (a dict or a list, depending on
        the file).
    """
    with open(path, "r", encoding="utf-8") as fp:
        return json.load(fp)
# --- Load the catalogue records and the label taxonomy ----------------------
data = load_json("data.json")
taxonomy = load_json("taxonomy_processed_v3.json")

# One "<group> - <label>" string per taxonomy entry; these are the corpus
# labels the prediction step ranks against.
taxonomy_labels = [entry["group"] + " - " + entry["label"] for entry in taxonomy]

# How often each theme (THEMA) and each designation (BEZEICHNUNG) occurs
# across the records.
theme_counts = dict(Counter(record["THEMA"] for record in data))
labels_counts = dict(Counter(record["BEZEICHNUNG"] for record in data))
# Group the taxonomy into {group: [label, ...]} for the treemap, keeping
# groups in order of first appearance. Entries of the catch-all group
# "Sonstiges" are listed under the leaf name "Sonstiges " rather than their
# own label (NOTE(review): the trailing space presumably keeps the leaf name
# distinct from the group node; confirm).
taxonomy_group_label_mapper: dict = {}
for entry in taxonomy:
    group = entry["group"]
    leaf = "Sonstiges " if group == "Sonstiges" else entry["label"]
    taxonomy_group_label_mapper.setdefault(group, []).append(leaf)
# Build the taxonomy treemap. Only names and parents are handed to plotly;
# the per-node counts in `values` are computed but not used by the figure.
parents, name, values = get_tree_map_data(
    data=taxonomy_group_label_mapper,
    countings_parents=theme_counts,
    countings_labels=labels_counts,
    root="Musterdatenkatalog",
)

_treemap_layout = dict(
    margin=dict(t=50, l=25, r=25, b=25),
    height=1000,
    width=1000,
    template="plotly",
)
fig = px.treemap(parents=parents, names=name)
fig.update_layout(**_treemap_layout)
@st.cache_resource(show_spinner=False)
def _load_model_resources():
    """Load the corpus embeddings and the sentence-embedding model once.

    Streamlit re-executes this whole script on every widget interaction.
    Caching the result keeps the safetensors file and the SentenceTransformer
    from being reloaded on each rerun. ``show_spinner=False`` ensures the
    call does not render a spinner element ahead of ``st.set_page_config``
    (older Streamlit versions require that to be the first command).

    Returns:
        ``(tensors, model)``, where *tensors* maps every key stored in
        ``corpus_embeddings.pt`` to its tensor.
    """
    loaded = {}
    with safe_open("corpus_embeddings.pt", framework="pt", device="cpu") as f:
        for k in f.keys():
            loaded[k] = f.get_tensor(k)
    encoder = SentenceTransformer(
        model_name_or_path="and-effect/musterdatenkatalog_clf",
        device="cpu",
        use_auth_token=HF_TOKEN,
    )
    return loaded, encoder


tensors, model = _load_model_resources()
# Page setup runs before any other Streamlit element is rendered.
st.set_page_config(layout="wide")

# --- Header: headline figures of the catalogue ------------------------------
st.title("Musterdatenkatalog")
col1, col2, col3 = st.columns(3)
headline_metrics = (
    (col1, "Kommunale Datensätze", len(data)),
    (col2, "Themen", len(theme_counts)),
    (col3, "Bezeichnungen", len(labels_counts)),
)
for column, caption, amount in headline_metrics:
    column.metric(caption, amount)

# --- Taxonomy treemap -------------------------------------------------------
st.title("Taxonomy")
st.plotly_chart(fig)
st.title("Predict a Dataset")

# Render every column, and every vertical block nested inside a
# column-direction flex container, as a white "card" (rounded border plus
# soft shadow). unsafe_allow_html is needed so st.markdown passes the
# <style> tag through instead of escaping it.
# NOTE(review): the data-testid selectors target Streamlit's internal DOM
# and may stop matching after a Streamlit upgrade.
st.markdown(
    """
<style>
/* Style columns */
[data-testid="column"] {
border-radius: 15px;
background-color: white;
box-shadow: 0 0 10px #eee;
border: 1px solid #ddd;
padding: 1rem;;
}
/* Style containers */
[data-testid="stVerticalBlock"] > [style*="flex-direction: column;"] > [data-testid="stVerticalBlock"] {
border-radius: 15px;
background-color: white;
box-shadow: 0 0 10px #eee;
border: 1px solid #ddd;
padding: 1rem;;
}
</style>
""",
    unsafe_allow_html=True,
)
# Two-column layout: the query form on the (slightly wider) left, example
# queries on the right. The right column is filled first so that a clicked
# example is already stored in session state when the form is rendered.
col1, col2 = st.columns([1.2, 1])

with col2:
    st.subheader("Example Datasets")
    examples = [
        "Spielplätze",
        "Berliner Weihnachtsmärkte 2022",
        "Hochschulwechslerquoten zum Masterstudium nach Bundesländern",
        "Umringe der Bebauungspläne von Etgert",
    ]
    for example in examples:
        # Remember the clicked example so the text input is pre-filled with it.
        if st.button(example):
            st.session_state["query"] = example

with col1:
    selected = st.session_state.get("query")
    if selected in examples:
        query = st.text_input("Enter dataset name", value=selected)
    else:
        # No valid example selected: drop any stale entry, start empty.
        st.session_state.pop("query", None)
        query = st.text_input("Enter dataset name")

    top_k = st.select_slider("Top Results", options=[1, 2, 3, 4, 5], value=1)

    # Only run the model when the user asks for it; every widget
    # interaction reruns this script.
    if st.button("Predict"):
        if not query.strip():
            st.warning("Please enter a dataset name.")
        else:
            predictions = predict(
                query=query,
                corpus_embeddings=tensors["corpus_embeddings"],
                corpus_labels=taxonomy_labels,
                top_k=top_k,
                model=model,
            )
            for prediction in predictions:
                st.write(prediction)