File size: 4,182 Bytes
dec332b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from collections import defaultdict
from enum import Enum
import operator
from typing import Callable, List, NamedTuple, Optional

import numpy as np

from trulens_eval.utils.serial import SerialModel


class ResultCategoryType(Enum):
    """
    Feedback result outcomes, ordered best to worst.

    The integer value of each member doubles as its rank: 0 is the best
    outcome (PASS) and 2 the worst (FAIL).
    """

    # Unpacking assigns PASS = 0, WARNING = 1, FAIL = 2, in that order.
    PASS, WARNING, FAIL = range(3)


class CATEGORY:
    """
    Feedback result categories for displaying purposes: pass, warning, fail, or
    unknown.

    `PASS`, `WARNING` and `FAIL` each map a direction name (one of the names in
    `directions`) to the `Category` describing that result under that
    direction. `UNKNOWN` is a single `Category` that `of_score` returns when a
    score satisfies none of the thresholds.
    """

    class Category(SerialModel):
        """Display and threshold information for one result category."""

        name: str  # lower-cased ResultCategoryType name, e.g. "pass"
        adjective: str  # "low", "medium" or "high" ("unknown" for UNKNOWN)
        threshold: float  # score boundary checked with `compare`
        color: str  # hex color string, e.g. "#aaffaa"
        icon: str
        # Name of the FeedbackDirection this category belongs to; None for
        # the direction-independent UNKNOWN category.
        direction: Optional[str] = None
        # Called as compare(score, threshold); True means the score falls in
        # this category. None for UNKNOWN.
        compare: Optional[Callable[[float, float], bool]] = None

    class FeedbackDirection(NamedTuple):
        """A scoring direction together with its three category thresholds."""

        name: str
        # True when lower scores are better, so thresholds rank ascending from
        # best to worst; False when higher scores are better.
        ascending: bool
        thresholds: List[float]

    # support both directions by default
    # TODO: make this configurable (per feedback definition & per app?)
    directions = [
        FeedbackDirection("HIGHER_IS_BETTER", False, [0, 0.6, 0.8]),
        FeedbackDirection("LOWER_IS_BETTER", True, [0.2, 0.4, 1]),
    ]

    styling = {
        "PASS": dict(color="#aaffaa", icon="✅"),
        "WARNING": dict(color="#ffffaa", icon="⚠️"),
        "FAIL": dict(color="#ffaaaa", icon="🛑"),
    }

    # One direction-name -> Category mapping per ResultCategoryType member.
    # These are defaultdicts, so looking up an unsupported direction yields {}
    # rather than raising KeyError.
    _by_name = {
        name: defaultdict(dict) for name in ResultCategoryType._member_names_
    }
    PASS = _by_name["PASS"]
    WARNING = _by_name["WARNING"]
    FAIL = _by_name["FAIL"]

    for _direction in directions:
        # Pair adjectives with the thresholds in increasing order, then order
        # the pairs best-first: descending for higher-is-better, ascending for
        # lower-is-better.
        _ranked = sorted(
            zip(["low", "medium", "high"], sorted(_direction.thresholds)),
            key=operator.itemgetter(1),
            reverse=not _direction.ascending,
        )
        # Derive the comparison from the same flag that drives the ordering:
        # lower-is-better matches at or below a threshold, higher-is-better at
        # or above it.
        _compare = operator.le if _direction.ascending else operator.ge

        # Rank 0 (best band) is PASS, rank 1 is WARNING, rank 2 is FAIL.
        for _rank, (_adjective, _threshold) in enumerate(_ranked):
            _category_name = ResultCategoryType(_rank).name
            _by_name[_category_name][_direction.name] = Category(
                name=_category_name.lower(),
                adjective=_adjective,
                threshold=_threshold,
                direction=_direction.name,
                compare=_compare,
                **styling[_category_name],
            )

    # Drop build-time temporaries so they do not linger as class attributes.
    del _by_name, _direction, _ranked, _compare
    del _rank, _adjective, _threshold, _category_name

    UNKNOWN = Category(
        name="unknown",
        adjective="unknown",
        threshold=np.nan,
        color="#aaaaaa",
        icon="?"
    )

    # order matters here because `of_score` returns the first best category
    ALL = [PASS, WARNING, FAIL]  # not including UNKNOWN intentionally

    @staticmethod
    def of_score(score: float, higher_is_better: bool = True) -> Category:
        """
        Classify a feedback score into a display category.

        Args:
            score: The feedback score to classify.
            higher_is_better: If True, use the "HIGHER_IS_BETTER" thresholds;
                otherwise use the "LOWER_IS_BETTER" ones.

        Returns:
            The first of PASS, WARNING, FAIL (checked in that order) whose
            `compare(score, threshold)` holds for the chosen direction, or
            `CATEGORY.UNKNOWN` if none does (e.g. for NaN, whose comparisons
            are always False).
        """
        direction_key = "HIGHER_IS_BETTER" if higher_is_better else "LOWER_IS_BETTER"

        for cat in map(operator.itemgetter(direction_key), CATEGORY.ALL):
            if cat.compare(score, cat.threshold):
                return cat

        return CATEGORY.UNKNOWN


# Default feedback direction; `root_js` below exports its thresholds.
default_direction = "HIGHER_IS_BETTER"

# JS globals exposing the default-direction thresholds. These would be useful
# to include in our pages, but we don't yet see a way to do this in Streamlit.
root_js = f"""
    var default_pass_threshold = {CATEGORY.PASS[default_direction].threshold};
    var default_warning_threshold = {CATEGORY.WARNING[default_direction].threshold};
    var default_fail_threshold = {CATEGORY.FAIL[default_direction].threshold};
"""

# Not presently used. Need to figure out how to include this in Streamlit pages.
root_html = f"""
<script>
    {root_js}
</script>
"""

# CSS that hides the arrow icon of Streamlit metric deltas.
stmetricdelta_hidearrow = """
    <style> [data-testid="stMetricDelta"] svg { display: none; } </style>
    """

# Names of all supported feedback directions, derived from CATEGORY so this
# list cannot drift from the categories actually built there.
valid_directions = [direction.name for direction in CATEGORY.directions]


def _cellstyle_js(direction: CATEGORY.FeedbackDirection) -> str:
    """
    Build the JavaScript cell-style function for one feedback direction.

    The generated `function(params)` parses `params.value` as a float and
    returns the colors of the first of PASS, WARNING, FAIL whose threshold the
    value satisfies, falling back to the UNKNOWN color otherwise (e.g. for
    NaN). Presumably consumed as a JS grid `cellStyle` callback — confirm at
    call sites.

    Args:
        direction: One of `CATEGORY.directions`.

    Returns:
        The JavaScript source as a string.
    """
    # Mirror CATEGORY's Python comparison: lower-is-better matches at or below
    # a threshold, higher-is-better at or above it.
    op = "<=" if direction.ascending else ">="

    header = """function(params) {
        let v = parseFloat(params.value);
        """

    # One early-return branch per category, best first, so the first
    # satisfied threshold wins (same order as `CATEGORY.of_score`).
    branches = "\n".join(
        f"""
        if (v {op} {cat.threshold}) {{
            return {{
                'color': 'black',
                'backgroundColor': '{cat.color}'
            }};
        }}
    """
        for cat in (by_direction[direction.name] for by_direction in CATEGORY.ALL)
    )

    footer = f"""
        // i.e. not a number
        return {{
            'color': 'black',
            'backgroundColor': '{CATEGORY.UNKNOWN.color}'
        }};
    }}"""

    return header + branches + footer


# JavaScript cell-style function keyed by direction name, in the same order as
# `valid_directions`.
cellstyle_jscode = {
    direction.name: _cellstyle_js(direction)
    for direction in CATEGORY.directions
}

# CSS that hides the first header cell and all body header cells — presumably
# the row-index column of a rendered table.
hide_table_row_index = """
    <style>
        thead tr th:first-child {display:none}
        tbody th {display:none}
    </style>
    """