# query를 자동으로 읽고 쓰는 container를 정의 from __future__ import annotations import re from typing import Callable, TypeVar import streamlit as st __all__ = ["QueryWrapper", "get_base_url"] T = TypeVar("T") import hashlib import urllib.parse def SHA1(msg: str) -> str: return hashlib.sha1(msg.encode()).hexdigest()[:8] def get_base_url(): session = st.runtime.get_instance()._session_mgr.list_active_sessions()[0] return urllib.parse.urlunparse( [session.client.request.protocol, session.client.request.host, "", "", "", ""] ) class QueryWrapper: queries: dict[str, _QueryWrapper] = {} # 기록용 def __init__(self, query: str, label: str | None = None, use_hash: bool = True): self.__wrapper = QueryWrapper.queries[query] = _QueryWrapper( query, label, use_hash ) def __call__(self, *args, **kwargs): return self.__wrapper(*args, **kwargs) @classmethod def get_sharable_link(cls): # for k, v in cls.queries.items(): # print(f"{k}: {v}") return re.sub( "&+", "&", "&".join([str(v) for k, v in cls.queries.items()]) ).strip("&") class _QueryWrapper: ILLEGAL_CHARS = "&/=?" def __init__(self, query: str, label: str | None = None, use_hash: bool = True): self.query = query self.label = label or query self.use_hash = use_hash self.hash_table = {} self.key = None def __call__( self, base_container: Callable, legal_list: list[T], default: T | list[T] | None = None, *, key: str | None = None, **kwargs, ) -> T | list[T] | None: val_from_query = st.query_params.get_all(self.query.lower()) # print(val_from_query) legal = len(val_from_query) > 0 self.key = key or self.label self.hash_table = {SHA1(str(v)): v for v in legal_list} # filter out illegal values if legal and legal_list: val_from_query = [v for v in val_from_query if v in self.hash_table] # print(self.label, val_from_query, legal) if legal: selected = [self.hash_table[v] for v in val_from_query] elif default: selected = default elif self.label in st.session_state: selected = st.session_state[self.label] if legal_list: if isinstance(selected, list): selected = [v for v in selected if v in legal_list] elif selected not in legal_list: selected = [] else: selected = [] if selected is None: pass elif len(selected) == 1 and base_container in [st.selectbox, st.radio]: selected = selected[0] # print(self.label, selected) if base_container == st.checkbox: selected = base_container( self.label, legal_list, index=legal_list.index(selected) if selected in legal_list else None, key=self.key, **kwargs, ) elif base_container == st.multiselect: selected = base_container( self.label, legal_list, default=selected, key=self.key, **kwargs ) elif base_container == st.radio: selected = base_container( self.label, legal_list, index=legal_list.index(selected) if selected in legal_list else None, key=self.key, **kwargs, ) elif base_container == st.selectbox: selected = base_container( self.label, legal_list, index=legal_list.index(selected) if selected in legal_list else None, key=self.key, **kwargs, ) else: selected = base_container(self.label, legal_list, key=self.key, **kwargs) return st.session_state[self.key] def __str__(self): selected = st.session_state.get(self.key, None) if isinstance(selected, str): return f"{self.query.lower()}={SHA1(selected)}" elif isinstance(selected, list): return "&".join([f"{self.query.lower()}={SHA1(str(v))}" for v in selected]) else: return ""