Spaces:
Running
Running
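# Streamlit widget wrappers that read their initial selection from the URL query string and can serialise the current selection back into a shareable query string.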
from __future__ import annotations | |
import re | |
from typing import Callable, TypeVar | |
import streamlit as st | |
# Generic option type accepted by the widget wrappers below.
T = TypeVar("T")

# Names exported by `from <module> import *`.
__all__ = [
    "QueryWrapper",
    "get_base_url",
]
import hashlib | |
import urllib.parse | |
def SHA1(msg: str) -> str:
    """Return the first 8 hex characters of the SHA-1 digest of *msg*.

    The string is encoded as UTF-8 before hashing.
    """
    digest = hashlib.sha1(msg.encode("utf-8"))
    full_hex = digest.hexdigest()
    return full_hex[:8]
def get_base_url():
    """Return ``protocol://host`` built from an active Streamlit session's request.

    Relies on private Streamlit internals (``_session_mgr``), which are not a
    public API and may change between Streamlit releases.

    NOTE(review): this takes the *first* entry of the active-session list, not
    necessarily the session of the user currently running the script. With
    several concurrent clients the result may come from another client's
    request. Indexing ``[0]`` also fails when no session is active.
    """
    active_sessions = st.runtime.get_instance()._session_mgr.list_active_sessions()
    request = active_sessions[0].client.request
    components = (request.protocol, request.host, "", "", "", "")
    return urllib.parse.urlunparse(components)
class QueryWrapper:
    """Register query-string-backed widgets and build a shareable query string.

    Each instance wraps one ``_QueryWrapper`` for a single query parameter and
    records it in the class-level ``queries`` registry, keyed by ``query``.
    ``get_sharable_link`` then serialises every registered wrapper's current
    selection into one query string.
    """

    # Registry of every wrapper created so far, keyed by query name. A later
    # registration with the same query name replaces the earlier one.
    # NOTE(review): this is class-level state, so it is shared by all code in
    # the same process that imports this module (presumably every Streamlit
    # session served by it) — confirm that is intended.
    queries: dict[str, _QueryWrapper] = {}

    def __init__(self, query: str, label: str | None = None, use_hash: bool = True):
        # Create the wrapper and register it under its query name in one step.
        self.__wrapper = QueryWrapper.queries[query] = _QueryWrapper(
            query, label, use_hash
        )

    def __call__(self, *args, **kwargs):
        """Render the widget; all arguments are forwarded to the wrapped instance."""
        return self.__wrapper(*args, **kwargs)

    @classmethod
    def get_sharable_link(cls) -> str:
        """Return the ``&``-joined query string of all registered wrappers.

        Works both as ``QueryWrapper.get_sharable_link()`` and on an instance.
        Wrappers that serialise to an empty string leave runs of ``&`` behind.
        Those runs are collapsed, and leading/trailing ``&`` are removed.
        """
        joined = "&".join(str(wrapper) for wrapper in cls.queries.values())
        return re.sub(r"&+", "&", joined).strip("&")
class _QueryWrapper:
    """Render one Streamlit widget whose selection is synced with a query parameter.

    Option values travel in the URL as ``SHA1(str(value))``, an 8-character hex
    prefix. ``__call__`` maps hashes found in ``st.query_params`` back to
    options. ``__str__`` emits ``query=hash`` pairs for the widget's current
    value.
    """

    # Query-string delimiter characters.
    # NOTE(review): not referenced in this class; presumably meant for
    # validating query names or labels — confirm.
    ILLEGAL_CHARS = "&/=?"

    def __init__(self, query: str, label: str | None = None, use_hash: bool = True):
        self.query = query  # parameter name; looked up lower-cased
        self.label = label or query  # widget label, defaults to the query name
        # NOTE(review): stored but never read — values are always hashed.
        self.use_hash = use_hash
        self.hash_table: dict[str, object] = {}  # hash -> option, rebuilt per call
        self.key: str | None = None  # widget / session_state key, set in __call__

    def __call__(
        self,
        base_container: Callable,
        legal_list: list[T],
        default: T | list[T] | None = None,
        *,
        key: str | None = None,
        **kwargs,
    ) -> T | list[T] | None:
        """Render ``base_container`` with an initial selection and return its value.

        The initial selection is taken from the first source that applies:

        1. hashes of this parameter in the current URL that match an option;
        2. ``default``, if truthy;
        3. the widget's previous value in ``st.session_state[self.key]``,
           restricted to ``legal_list`` when ``legal_list`` is non-empty;
        4. otherwise ``[]``.

        Args:
            base_container: Streamlit widget function, such as
                ``st.multiselect``, ``st.selectbox`` or ``st.radio``.
            legal_list: The options offered by the widget.
            default: Initial value used when the URL carries no valid value.
            key: Widget / session-state key; defaults to the label.
            **kwargs: Forwarded unchanged to ``base_container``.

        Returns:
            The widget's value as stored in ``st.session_state[self.key]``.
        """
        self.key = key or self.label
        self.hash_table = {SHA1(str(v)): v for v in legal_list}

        # Keep only hashes that map to a current option. Emptiness is tested
        # *after* filtering, so a stale or tampered link falls back to
        # default/session state instead of forcing an empty selection, and an
        # empty legal_list no longer raises KeyError.
        from_query = [
            h
            for h in st.query_params.get_all(self.query.lower())
            if h in self.hash_table
        ]

        if from_query:
            selected = [self.hash_table[h] for h in from_query]
        elif default:
            selected = default
        elif self.key in st.session_state:
            # The widget below stores its state under self.key, not self.label.
            selected = st.session_state[self.key]
            if legal_list:
                if isinstance(selected, list):
                    selected = [v for v in selected if v in legal_list]
                elif selected not in legal_list:
                    selected = []
        else:
            selected = []

        # Single-choice widgets take one value, not a one-element sequence.
        # Only unwrap real sequences; scalars such as ints have no len().
        if (
            isinstance(selected, (list, tuple))
            and len(selected) == 1
            and base_container in (st.selectbox, st.radio)
        ):
            selected = selected[0]

        if base_container == st.multiselect:
            base_container(
                self.label, legal_list, default=selected, key=self.key, **kwargs
            )
        elif base_container in (st.checkbox, st.radio, st.selectbox):
            # NOTE(review): st.checkbox's documented signature has no options
            # list or `index` parameter, so this call appears to fail for
            # st.checkbox — confirm the intended usage.
            index = legal_list.index(selected) if selected in legal_list else None
            base_container(self.label, legal_list, index=index, key=self.key, **kwargs)
        else:
            base_container(self.label, legal_list, key=self.key, **kwargs)
        return st.session_state[self.key]

    def __str__(self) -> str:
        """Return ``query=hash`` pairs for the widget's current value, joined by ``&``.

        A list contributes one pair per element. Any other non-``None`` value
        contributes a single pair. A missing or ``None`` value yields ``""``.
        Values are hashed via ``str(value)``, matching how ``__call__`` builds
        ``hash_table``.
        """
        selected = st.session_state.get(self.key, None)
        name = self.query.lower()
        if selected is None:
            return ""
        if isinstance(selected, list):
            return "&".join(f"{name}={SHA1(str(v))}" for v in selected)
        # str(s) == s for strings, so string values hash exactly as before.
        # Other scalars (e.g. int options) were previously dropped from links.
        return f"{name}={SHA1(str(selected))}"