Spaces:
Runtime error
Runtime error
File size: 2,797 Bytes
8a1d0f1 2fda096 8a1d0f1 2fda096 8a1d0f1 2fda096 7562221 44df589 2fda096 51628e8 2fda096 51628e8 88468c1 51628e8 88468c1 51628e8 88468c1 06a6958 2fda096 44df589 88468c1 2fda096 8a1d0f1 2fda096 06a6958 2fda096 5ac5327 2fda096 5ac5327 2fda096 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import streamlit as st
from datasets import load_dataset
# Page-level settings; kept as the first Streamlit call of the script.
st.set_page_config(page_icon="🧊", layout="wide")

# Short introduction displayed at the top of the page.
st.write(
    "This is an application for viewing different generations for the same prompt. "
    "The generations vary depending on the checkpoint used and also the parameters used "
    "for the generation."
)

# Token read from the Streamlit secrets store; passed to `load_dataset` further down.
HF_API_TOKEN = st.secrets["HF_API_TOKEN"]

# Hex colour applied to the generated text by `generation_markup_format`.
PROMPT_COLOR = "#CA437E"
def safe_text(text):
    """Return ``text`` as an HTML ``<pre>`` snippet that displays it literally.

    The result is meant to be rendered with ``unsafe_allow_html=True``, so the
    HTML special characters ``&``, ``<`` and ``>`` are escaped first; otherwise
    any markup contained in the text would be interpreted instead of shown.
    Newlines are then turned into ``<br>`` tags.

    Args:
        text: The raw string to display.

    Returns:
        The escaped text wrapped in ``<pre>...</pre>``.
    """
    import html  # function-scope import: the module is only needed here

    # Escape before inserting <br>, so the tags added below are not escaped too.
    escaped = html.escape(text, quote=False)
    with_breaks = escaped.replace("\n", "<br>")
    return f"<pre>{with_breaks}</pre>"
def prompt_markup_format(text):
    """Return ``text`` wrapped in a black ``<font>`` element.

    ``text`` is inserted unchanged, so it may itself contain HTML.

    Args:
        text: The string (or HTML snippet) to wrap.

    Returns:
        The HTML string ``<font color="black">...</font>``.
    """
    # The tag name must be a plain "font": a leading "*" makes it invalid
    # HTML, so the wrapper would not be parsed as an element.
    return f'<font color="black">{text}</font>'
def generation_markup_format(text):
    """Return ``text`` wrapped in a ``<font>`` element coloured with ``PROMPT_COLOR``.

    ``text`` is inserted unchanged, so it may itself contain HTML.

    Args:
        text: The string (or HTML snippet) to wrap.

    Returns:
        The HTML string ``<font color="...">...</font>``.
    """
    # Only the <font> element opened here is closed here: emitting a closing
    # </pre> as well would leave an unmatched tag in the output.
    return f'<font color="{PROMPT_COLOR}">{text}</font>'
# Load the generations dataset and keep only its "train" split.
ds = load_dataset("bigscience/bloom-generations", use_auth_token=HF_API_TOKEN)
ds = ds["train"]

col_1, col_2 = st.columns(2)

# Left column: narrow the dataset by checkpoint, then by language, then by prompt.
with col_1:
    possible_checkpoint = ds.unique("checkpoint")
    st.markdown("<h1 style='text-align: center'>Prompt</h1>", unsafe_allow_html=True)

    # "all" is an extra choice that disables the corresponding filter.
    chosen_checkpoint = st.selectbox("Choose a checkpoint", possible_checkpoint + ["all"])
    if chosen_checkpoint == "all":
        ds_ckp = ds
    else:
        ds_ckp = ds.filter(
            lambda exs: [ckpt == chosen_checkpoint for ckpt in exs["checkpoint"]],
            batched=True,
        )

    # List only the languages present for the selected checkpoint; taking them
    # from the unfiltered dataset could offer a language with no matching row.
    possible_langs = ds_ckp.unique("lang")
    chosen_lang = st.selectbox("Choose a lang", possible_langs + ["all"])
    if chosen_lang == "all":
        ds_lang = ds_ckp
    else:
        ds_lang = ds_ckp.filter(
            lambda exs: [lang == chosen_lang for lang in exs["lang"]],
            batched=True,
        )

    possible_prompts = ds_lang.unique("prompt")
    if not possible_prompts:
        # Nothing to show: without a prompt the widgets below would receive
        # None and an invalid index range, so end this script run here.
        st.warning("No prompt is available for this selection.")
        st.stop()

    chosen_prompt = st.selectbox("Choose a prompt", possible_prompts)
    st.markdown(safe_text(chosen_prompt), unsafe_allow_html=True)

    # Every row whose prompt is the selected one.
    sub_ds = ds_lang.filter(
        lambda exs: [prompt == chosen_prompt for prompt in exs["prompt"]],
        batched=True,
    )

# Right column: show one generation for the selected prompt and its settings.
with col_2:
    st.markdown(
        "<h1 style='text-align: center'>Generation</h1>", unsafe_allow_html=True
    )

    # Which of the matching rows to display.
    index_sample = st.number_input(
        "Index of the chosen generation",
        min_value=0,
        max_value=len(sub_ds) - 1,
        value=0,
        step=1,
    )
    sample = sub_ds[index_sample]
    generation = sample["generation"]

    # Optionally display only the first N characters of the generation.
    stop_index_sample = st.number_input(
        "Stop generation at character number",
        min_value=0,
        max_value=len(generation),
        value=len(generation),
        step=1,
    )
    markdown_text = generation_markup_format(safe_text(generation[:stop_index_sample]))
    st.markdown(markdown_text, unsafe_allow_html=True)

    st.markdown(
        "<h2 style='text-align: center'>Generation configuration</h2>",
        unsafe_allow_html=True,
    )
    # Every field of the row other than the two texts already shown above.
    config = {
        key: value
        for key, value in sample.items()
        if key not in ("prompt", "generation")
    }
    # Written explicitly instead of relying on a bare expression being
    # displayed by Streamlit's "magic" feature.
    st.write(config)
|