import streamlit as st
from datasets import load_dataset

st.set_page_config(
    page_icon="🧊",
    layout="wide",
)

st.write(
    "This application lets you view different generations for the same prompt. "
    "The generations vary depending on the checkpoint and on the generation parameters used."
)

HF_API_TOKEN = st.secrets["HF_API_TOKEN"]
PROMPT_COLOR = "#CA437E"


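# Escape newlines as <br> and wrap the text in <pre> tags so it renders verbatim via st.markdown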
def safe_text(text):
    text = text.replace("\n", "<br>")
    return f"<pre>{text}</pre>"


def prompt_markup_format(text):
    return f'<font color="black">{text}</font>'


def generation_markup_format(text):
    # Color the generated text; safe_text() already wraps it in <pre> tags
    return f'<font color="{PROMPT_COLOR}">{text}</font>'


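# Load the BLOOM generations dataset from the Hugging Face Hub (requires an access token)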
ds = load_dataset("bigscience/bloom-generations", use_auth_token=HF_API_TOKEN)
ds = ds["train"]


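# Two-column layout: the left column narrows down to a single prompt
# (checkpoint -> language -> prompt), the right column shows the generations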
col_1, col_2 = st.columns(2)
with col_1:
    possible_checkpoint = ds.unique("checkpoint")
    st.markdown("<h1 style='text-align: center'>Prompt</h1>", unsafe_allow_html=True)
    chosen_checkpoint = st.selectbox("Choose a checkpoint", possible_checkpoint + ["all"])
    if chosen_checkpoint == "all":
        ds_ckp = ds
    else:
        ds_ckp = ds.filter(
            lambda exs: [ckpt == chosen_checkpoint for ckpt in exs["checkpoint"]],
            batched=True,
        )
    possible_langs = ds_ckp.unique("lang")
    chosen_lang = st.selectbox("Choose a language", possible_langs + ["all"])
    if chosen_lang == "all":
        ds_lang = ds_ckp
    else:
        ds_lang = ds_ckp.filter(
            lambda exs: [lang == chosen_lang for lang in exs["lang"]],
            batched=True,
        )

    possible_prompts = ds_lang.unique("prompt")
    chosen_prompt = st.selectbox("Choose a prompt", possible_prompts)
    st.markdown(safe_text(chosen_prompt), unsafe_allow_html=True)

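# Keep only the rows matching the selected prompt; each row is one generation run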
sub_ds = ds_lang.filter(
    lambda exs: [prompt == chosen_prompt for prompt in exs["prompt"]], batched=True
)


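# Right column: pick one of the generations for the chosen prompt and display it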
with col_2:
    st.markdown(
        "<h1 style='text-align: center'>Generation</h1>", unsafe_allow_html=True
    )
    index_sample = st.number_input(
        "Index of the chosen generation",
        min_value=0,
        max_value=len(sub_ds) - 1,
        value=0,
        step=1,
    )
    
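    # Fetch the selected row and optionally truncate the generated text for display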
    sample = sub_ds[index_sample]
    generation = sample["generation"]
    stop_index_sample = st.number_input(
        "Stop generation at character number",
        min_value=0,
        max_value=len(generation),
        value=len(generation),
        step=1,
    )
    markdown_text = generation_markup_format(safe_text(generation[:stop_index_sample]))
    st.markdown(markdown_text, unsafe_allow_html=True)
    st.markdown(
        "<h2 style='text-align: center'>Generation configuration</h2>",
        unsafe_allow_html=True,
    )
    # Everything except the prompt and the generated text is generation configuration
    config = {
        key: value
        for key, value in sample.items()
        if key not in ["prompt", "generation"]
    }
    st.write(config)