# VARCO_Arena / query_comp.py (sonsus, commit c2ba4d5: "others")
# Defines widget containers that automatically read their selection from,
# and write it back to, the URL query string.
from __future__ import annotations
import re
from typing import Callable, TypeVar
import streamlit as st
# Names exported by `from query_comp import *`.
__all__ = ["QueryWrapper", "get_base_url"]
# Generic option type used in the annotations of `_QueryWrapper.__call__`.
T = TypeVar("T")
import hashlib
import urllib.parse
def SHA1(msg: str) -> str:
    """Return a short fingerprint of *msg*: the first 8 hex digits of its SHA-1.

    The text is encoded as UTF-8 (the ``str.encode`` default) before hashing.
    """
    full_digest = hashlib.sha1(msg.encode("utf-8")).hexdigest()
    return full_digest[:8]
def get_base_url():
    """Return the ``<protocol>://<host>`` base URL of the browser connection.

    Prefers the connection belonging to the session that is running the
    current script. When that cannot be determined, it falls back to the first
    active session, which is the previous behavior.

    NOTE(review): relies on Streamlit internals (``Runtime.get_client``,
    ``Runtime._session_mgr`` and the ``request`` object on the session
    client), none of which are public API; they may change between Streamlit
    releases.

    Raises:
        IndexError: if the current session cannot be resolved and there are
            no active sessions.
    """
    runtime = st.runtime.get_instance()
    client = None

    try:
        from streamlit.runtime.scriptrunner import get_script_run_ctx
    except ImportError:  # module layout differs in this Streamlit version
        get_script_run_ctx = None

    if get_script_run_ctx is not None:
        ctx = get_script_run_ctx()
        if ctx is not None:
            # With several users connected, the first active session may be
            # someone else's (possibly reached via a different host), so use
            # the session serving this script run.
            client = runtime.get_client(ctx.session_id)

    if client is None:
        # Fallback: previous behavior (first active session).
        client = runtime._session_mgr.list_active_sessions()[0].client

    request = client.request
    return urllib.parse.urlunparse(
        [request.protocol, request.host, "", "", "", ""]
    )
class QueryWrapper:
    """Public entry point for query-string-backed widgets.

    Each instance creates a ``_QueryWrapper`` for one query-parameter name and
    records it in the class-level ``queries`` registry. Calling the instance
    forwards to that wrapper. :meth:`get_sharable_link` joins the string form
    of every registered wrapper into one query string.
    """

    # Registry of wrappers, keyed by the query name given to the constructor.
    # A later QueryWrapper with the same name replaces the earlier entry.
    # NOTE(review): class-level state, so it is shared by everything in the
    # same Python process (e.g. all Streamlit sessions) — confirm intended.
    queries: dict[str, _QueryWrapper] = {}

    def __init__(self, query: str, label: str | None = None, use_hash: bool = True):
        inner = _QueryWrapper(query, label, use_hash)
        self.__wrapper = inner
        QueryWrapper.queries[query] = inner

    def __call__(self, *args, **kwargs):
        """Forward every argument to the wrapped ``_QueryWrapper`` and return its result."""
        target = self.__wrapper
        return target(*args, **kwargs)

    @classmethod
    def get_sharable_link(cls):
        """Return the query string (no leading ``?``) for all registered wrappers.

        Empty fragments are removed by collapsing runs of ``&`` and trimming
        ``&`` from both ends of the joined string.
        """
        fragments = [str(wrapper) for wrapper in cls.queries.values()]
        joined = "&".join(fragments)
        collapsed = re.sub("&+", "&", joined)
        return collapsed.strip("&")
class _QueryWrapper:
    """Render a Streamlit widget whose selection is seeded from one URL query
    parameter, and serialize its current selection back into query form.

    Options are not written to the URL verbatim. Each option is identified by
    ``SHA1(str(option))``, and ``str(self)`` emits ``<query>=<hash>`` pairs
    for the widget's current selection.
    """

    # Presumably the characters reserved in query strings.
    # NOTE(review): not referenced inside this class; kept as a class
    # attribute in case other code reads it.
    ILLEGAL_CHARS = "&/=?"

    def __init__(self, query: str, label: str | None = None, use_hash: bool = True):
        """
        Args:
            query: Query-parameter name (used lower-cased in the URL).
            label: Widget label; defaults to ``query``.
            use_hash: Stored on the instance.
                NOTE(review): not consulted anywhere; options are always
                hashed regardless of this flag.
        """
        self.query = query
        self.label = label or query
        self.use_hash = use_hash
        # SHA1(str(option)) -> option, rebuilt from legal_list on each render.
        self.hash_table = {}
        # Session-state key of the widget; None until the first render.
        self.key = None

    def __call__(
        self,
        base_container: Callable,
        legal_list: list[T],
        default: T | list[T] | None = None,
        *,
        key: str | None = None,
        **kwargs,
    ) -> T | list[T] | None:
        """Render ``base_container`` with an initial selection and return its value.

        The initial selection comes from the first source that yields one:
        1. query-parameter values whose hash matches an option in ``legal_list``;
        2. ``default``, if truthy;
        3. the widget's current value in ``st.session_state[self.key]``,
           restricted to ``legal_list`` when that list is non-empty;
        4. otherwise an empty selection.

        Args:
            base_container: Streamlit widget function, e.g. ``st.selectbox``,
                ``st.radio`` or ``st.multiselect``.
            legal_list: Options offered by the widget.
            default: Fallback selection when the URL provides no usable value.
            key: Session-state key for the widget; defaults to the label.
            **kwargs: Forwarded to ``base_container``.

        Returns:
            ``st.session_state[self.key]`` after the widget has been rendered.
        """
        self.key = key or self.label
        self.hash_table = {SHA1(str(option)): option for option in legal_list}

        selected = self._initial_selection(legal_list, default)
        # Single-choice widgets expect a scalar rather than a 1-element list.
        # Only real sequences are unwrapped: scalars (e.g. int options) have
        # no len().
        if (
            base_container in (st.selectbox, st.radio)
            and isinstance(selected, (list, tuple))
            and len(selected) == 1
        ):
            selected = selected[0]

        self._render(base_container, legal_list, selected, **kwargs)
        return st.session_state[self.key]

    def _initial_selection(self, legal_list, default):
        """Pick the starting selection by the priority documented on ``__call__``."""
        # Keep only URL values that map to a current option; this also avoids
        # a KeyError when legal_list (and thus hash_table) is empty.
        from_query = [
            self.hash_table[digest]
            for digest in st.query_params.get_all(self.query.lower())
            if digest in self.hash_table
        ]
        if from_query:
            return from_query
        if default:
            return default
        if self.key in st.session_state:
            previous = st.session_state[self.key]
            if not legal_list:
                return previous
            if isinstance(previous, list):
                return [v for v in previous if v in legal_list]
            return previous if previous in legal_list else []
        return []

    def _render(self, base_container, legal_list, selected, **kwargs):
        """Call ``base_container`` with ``selected`` passed in the form it expects."""
        if base_container in (st.checkbox, st.radio, st.selectbox):
            # NOTE(review): st.checkbox accepts neither an options list nor
            # ``index``, so this call fails for checkboxes. It is left as-is
            # because the intended checkbox semantics are unclear.
            index = legal_list.index(selected) if selected in legal_list else None
            return base_container(
                self.label, legal_list, index=index, key=self.key, **kwargs
            )
        if base_container == st.multiselect:
            return base_container(
                self.label, legal_list, default=selected, key=self.key, **kwargs
            )
        return base_container(self.label, legal_list, key=self.key, **kwargs)

    def __str__(self):
        """Serialize the current selection as ``<query>=<hash>`` pairs joined by ``&``.

        Returns an empty string when the widget has not been rendered yet or
        its value is None.
        """
        if self.key is None:
            return ""
        selected = st.session_state.get(self.key, None)
        name = self.query.lower()
        if isinstance(selected, list):
            return "&".join(f"{name}={SHA1(str(v))}" for v in selected)
        if selected is None:
            return ""
        # Scalars: str values as before, and also non-str options (ints,
        # floats, ...). Hash str(value) to match how __call__ builds
        # hash_table, so the link round-trips.
        return f"{name}={SHA1(str(selected))}"