from collections import defaultdict
from enum import Enum
import operator
from typing import Callable, List, NamedTuple, Optional
import numpy as np
from trulens_eval.utils.serial import SerialModel
class ResultCategoryType(Enum):
PASS = 0
WARNING = 1
FAIL = 2
class CATEGORY:
"""
Feedback result categories for displaying purposes: pass, warning, fail, or
unknown.
"""
class Category(SerialModel):
name: str
adjective: str
threshold: float
color: str
icon: str
direction: Optional[str] = None
compare: Optional[Callable[[float, float], bool]] = None
class FeedbackDirection(NamedTuple):
name: str
ascending: bool
thresholds: List[float]
# support both directions by default
# TODO: make this configurable (per feedback definition & per app?)
directions = [
FeedbackDirection("HIGHER_IS_BETTER", False, [0, 0.6, 0.8]),
FeedbackDirection("LOWER_IS_BETTER", True, [0.2, 0.4, 1]),
]
styling = {
"PASS": dict(color="#aaffaa", icon="✅"),
"WARNING": dict(color="#ffffaa", icon="⚠️"),
"FAIL": dict(color="#ffaaaa", icon="🛑"),
}
for category_name in ResultCategoryType._member_names_:
locals()[category_name] = defaultdict(dict)
for direction in directions:
a = sorted(
zip(["low", "medium", "high"], sorted(direction.thresholds)),
key=operator.itemgetter(1),
reverse=not direction.ascending,
)
for enum, (adjective, threshold) in enumerate(a):
category_name = ResultCategoryType(enum).name
locals()[category_name][direction.name] = Category(
name=category_name.lower(),
adjective=adjective,
threshold=threshold,
direction=direction.name,
compare=operator.ge
if direction.name == "HIGHER_IS_BETTER" else operator.le,
**styling[category_name],
)
UNKNOWN = Category(
name="unknown",
adjective="unknown",
threshold=np.nan,
color="#aaaaaa",
icon="?"
)
# order matters here because `of_score` returns the first best category
ALL = [PASS, WARNING, FAIL] # not including UNKNOWN intentionally
@staticmethod
def of_score(score: float, higher_is_better: bool = True) -> Category:
direction_key = "HIGHER_IS_BETTER" if higher_is_better else "LOWER_IS_BETTER"
for cat in map(operator.itemgetter(direction_key), CATEGORY.ALL):
if cat.compare(score, cat.threshold):
return cat
return CATEGORY.UNKNOWN
default_direction = "HIGHER_IS_BETTER"
# These would be useful to include in our pages but don't yet see a way to do
# this in streamlit.
root_js = f"""
var default_pass_threshold = {CATEGORY.PASS[default_direction].threshold};
var default_warning_threshold = {CATEGORY.WARNING[default_direction].threshold};
var default_fail_threshold = {CATEGORY.FAIL[default_direction].threshold};
"""
# Not presently used. Need to figure out how to include this in streamlit pages.
root_html = f"""
"""
stmetricdelta_hidearrow = """
"""
valid_directions = ["HIGHER_IS_BETTER", "LOWER_IS_BETTER"]
cellstyle_jscode = {
k: f"""function(params) {{
let v = parseFloat(params.value);
""" + "\n".join(
f"""
if (v {'>=' if k == "HIGHER_IS_BETTER" else '<='} {cat.threshold}) {{
return {{
'color': 'black',
'backgroundColor': '{cat.color}'
}};
}}
""" for cat in map(operator.itemgetter(k), CATEGORY.ALL)
) + f"""
// i.e. not a number
return {{
'color': 'black',
'backgroundColor': '{CATEGORY.UNKNOWN.color}'
}};
}}""" for k in valid_directions
}
hide_table_row_index = """
"""