File size: 2,762 Bytes
87fe461
 
6684d9c
a467828
87fe461
98dff60
 
 
87fe461
6e61d26
98dff60
 
 
 
 
6e61d26
87fe461
 
bcb8af1
75161a8
17d778a
6e61d26
87fe461
 
 
 
6e61d26
87fe461
6e61d26
 
98dff60
 
 
 
 
6e61d26
 
bcb8af1
 
6e61d26
98dff60
 
 
a467828
9826c29
98dff60
 
 
9826c29
0c196bb
 
 
 
 
 
 
 
 
a467828
 
17d778a
75161a8
6e61d26
 
 
98dff60
 
6e61d26
87fe461
 
 
1d02673
 
 
 
 
 
 
17d778a
bcb8af1
17d778a
1d02673
ab818d6
 
17d778a
ab818d6
98dff60
 
 
bcb8af1
 
 
87fe461
 
75161a8
 
6e61d26
52a0611
6e61d26
a467828
6e61d26
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import streamlit as st

# TODO: rename and refactor everything


def select_num_interval(
    param_name: str, limits_list: list, defaults, n_for_hash, **kwargs
):
    """Show a titled sidebar slider for *param_name* and return its value.

    Args:
        param_name: Text shown as the sidebar subheader; also part of the widget key.
        limits_list: Sequence whose first two items are the slider's lower and
            upper bounds.
        defaults: Initial slider value, passed straight to ``st.sidebar.slider``
            (a pair here presumably yields a range slider — confirm against the
            Streamlit version in use).
        n_for_hash: Extra value mixed into the widget key so the same
            ``param_name`` can be rendered more than once.
        **kwargs: Ignored; accepted so callers can pass a whole config dict.

    Returns:
        Whatever ``st.sidebar.slider`` returns for these arguments.
    """
    st.sidebar.subheader(param_name)
    bottom = limits_list[0]
    top = limits_list[1]
    widget_key = hash(param_name + str(n_for_hash))
    return st.sidebar.slider("", bottom, top, defaults, key=widget_key)


def select_several_nums(
    param_name, subparam_names, limits_list, defaults_list, n_for_hash, **kwargs
):
    """Show one sidebar slider per sub-parameter under a shared subheader.

    Args:
        param_name: Subheader text; also the common prefix of every widget key.
        subparam_names: Label for each slider.
        limits_list: One ``[lower, upper]`` pair per slider.
        defaults_list: One initial value per slider.
        n_for_hash: Extra value mixed into each widget key to keep keys unique.
        **kwargs: Ignored; accepted so callers can pass a whole config dict.

    Returns:
        A tuple containing each slider's value, in ``subparam_names`` order.

    Raises:
        AssertionError: if the three per-slider sequences differ in length
        (only when assertions are enabled).
    """
    st.sidebar.subheader(param_name)
    assert len(limits_list) == len(defaults_list)
    assert len(subparam_names) == len(defaults_list)

    rows = zip(subparam_names, limits_list, defaults_list)
    return tuple(
        st.sidebar.slider(
            label,
            bounds[0],
            bounds[1],
            initial,
            key=hash(param_name + label + str(n_for_hash)),
        )
        for label, bounds, initial in rows
    )


def select_min_max(
    param_name, limits_list, defaults_list, n_for_hash, min_diff=0, **kwargs
):
    """Show one range slider for a pair of parameters and enforce a minimum width.

    Args:
        param_name: Two-element sequence of names; joined with " & " to form
            the slider's subheader.
        limits_list: Sequence whose first two items are the slider's lower and
            upper bounds.
        defaults_list: Initial slider value, passed through to the slider.
        n_for_hash: Extra value mixed into the widget key to keep keys unique.
        min_diff: Minimum allowed ``high - low`` of the returned interval.
        **kwargs: Ignored; accepted so callers can pass a whole config dict.

    Returns:
        A ``(low, high)`` tuple. If the user's selection is narrower than
        ``min_diff``, it is widened as follows. First, the top is raised by the
        shortfall if that fits within the limits. Otherwise, the bottom is
        lowered by the shortfall if that fits. Otherwise, the top is pinned to
        the upper limit and the bottom is set to ``upper - min_diff``. The
        whole ``limits_list`` is returned only when the limits themselves are
        narrower than ``min_diff``.

    Raises:
        AssertionError: if ``param_name`` does not have exactly two items
        (only when assertions are enabled).
    """
    assert len(param_name) == 2
    interval = list(
        select_num_interval(
            " & ".join(param_name), limits_list, defaults_list, n_for_hash
        )
    )
    low, high = interval[0], interval[1]
    lower_limit, upper_limit = limits_list[0], limits_list[1]

    shortfall = min_diff - (high - low)
    if shortfall > 0:
        if high + shortfall <= upper_limit:
            # Preferred: widen upward only.
            high = high + shortfall
        elif low - shortfall >= lower_limit:
            # Otherwise widen downward only.
            low = low - shortfall
        elif upper_limit - lower_limit >= min_diff:
            # Neither side alone has room. Pin the top to the upper limit and
            # take the rest from below. This is the narrowest valid interval
            # that still contains the user's selection, instead of jumping to
            # the full slider range.
            high = upper_limit
            low = upper_limit - min_diff
        else:
            # The slider range is narrower than min_diff; the full range is
            # the best available answer.
            return tuple(limits_list)
    return (low, high)


def select_RGB(param_name, n_for_hash, **kwargs):
    """Show Red/Green/Blue sliders (each 0-255, starting at 0) under *param_name*.

    Args:
        param_name: Subheader text and widget-key prefix.
        n_for_hash: Extra value mixed into each widget key to keep keys unique.
        **kwargs: Ignored; accepted so callers can pass a whole config dict.

    Returns:
        A tuple containing the three channel values, in R, G, B order.
    """
    channels = ["Red", "Green", "Blue"]
    # Build a fresh [0, 255] list per channel (no shared inner list).
    channel_limits = [[0, 255] for _ in channels]
    channel_defaults = [0 for _ in channels]
    picked = select_several_nums(
        param_name,
        subparam_names=channels,
        limits_list=channel_limits,
        defaults_list=channel_defaults,
        n_for_hash=n_for_hash,
    )
    return tuple(picked)


def replace_none(string):
    """Turn the literal text ``"None"`` into the ``None`` object.

    Any other value, string or not, is returned unchanged.
    """
    return None if string == "None" else string


def select_radio(param_name, options_list, n_for_hash, **kwargs):
    """Show a sidebar radio group for *param_name* and return the chosen option.

    Args:
        param_name: Subheader text; also part of the widget key.
        options_list: Options offered by the radio group.
        n_for_hash: Extra value mixed into the widget key to keep keys unique.
        **kwargs: Ignored; accepted so callers can pass a whole config dict.

    Returns:
        The selected option. An option equal to the string ``"None"`` is
        returned as the ``None`` object.
    """
    st.sidebar.subheader(param_name)
    widget_key = hash(param_name + str(n_for_hash))
    chosen = st.sidebar.radio("", options_list, key=widget_key)
    return replace_none(chosen)


def select_checkbox(param_name, defaults, n_for_hash, **kwargs):
    """Show a sidebar checkbox labelled "True" under a *param_name* subheader.

    Args:
        param_name: Subheader text; also part of the widget key.
        defaults: Initial checked state, passed to ``st.sidebar.checkbox``.
        n_for_hash: Extra value mixed into the widget key to keep keys unique.
        **kwargs: Ignored; accepted so callers can pass a whole config dict.

    Returns:
        Whatever ``st.sidebar.checkbox`` returns (the checkbox state).
    """
    st.sidebar.subheader(param_name)
    widget_key = hash(param_name + str(n_for_hash))
    checked = st.sidebar.checkbox("True", defaults, key=widget_key)
    return checked


# Maps a parameter-type name to the function that renders a widget for that
# parameter type in the sidebar.
param2func = dict(
    num_interval=select_num_interval,
    several_nums=select_several_nums,
    radio=select_radio,
    rgb=select_RGB,
    checkbox=select_checkbox,
    min_max=select_min_max,
)