|
from datetime import datetime |
|
|
|
import streamlit as st |
|
|
|
from lib import base64_encode_image_file, config, txt2img_generate |
|
|
|
st.set_page_config( |
|
page_title=f"Text to Image - {config.title}", |
|
page_icon=config.icon, |
|
layout=config.layout, |
|
) |
|
|
|
if "api_key_bfl" not in st.session_state: |
|
st.session_state.api_key_bfl = "" |
|
|
|
if "api_key_fal" not in st.session_state: |
|
st.session_state.api_key_fal = "" |
|
|
|
if "api_key_hf" not in st.session_state: |
|
st.session_state.api_key_hf = "" |
|
|
|
if "api_key_together" not in st.session_state: |
|
st.session_state.api_key_together = "" |
|
|
|
if "running" not in st.session_state: |
|
st.session_state.running = False |
|
|
|
if "txt2img_messages" not in st.session_state: |
|
st.session_state.txt2img_messages = [] |
|
|
|
if "txt2img_seed" not in st.session_state: |
|
st.session_state.txt2img_seed = 0 |
|
|
|
st.logo(config.logo, size="small") |
|
st.sidebar.header("Settings") |
|
|
|
image_services = { |
|
service_id: service_config |
|
for service_id, service_config in config.services.items() |
|
if getattr(service_config, "image", None) |
|
} |
|
|
|
service = st.sidebar.selectbox( |
|
"Service", |
|
options=image_services.keys(), |
|
format_func=lambda x: image_services[x].name, |
|
disabled=st.session_state.running, |
|
) |
|
|
|
|
|
for service_id, service_config in image_services.items(): |
|
if service == service_id: |
|
session_key = f"api_key_{service}" |
|
api_key = service_config.api_key |
|
st.session_state[session_key] = st.sidebar.text_input( |
|
"API Key", |
|
type="password", |
|
value="" if api_key else st.session_state[session_key], |
|
disabled=bool(api_key) or st.session_state.running, |
|
help="Set by environment variable" if api_key else "Cleared on page refresh", |
|
) |
|
|
|
service_config = image_services[service] |
|
|
|
model = st.sidebar.selectbox( |
|
"Model", |
|
options=service_config.image.keys(), |
|
format_func=lambda x: service_config.image[x].name, |
|
disabled=st.session_state.running, |
|
) |
|
|
|
model_config = service_config.image[model] |
|
|
|
st.html(""" |
|
<h1>Text to Image</h1> |
|
<p>Generate an image from a text prompt.</p> |
|
""") |
|
|
|
|
|
parameters = {} |
|
for param in model_config.parameters: |
|
if param == "model": |
|
parameters[param] = model |
|
|
|
if param == "seed": |
|
parameters[param] = st.sidebar.number_input( |
|
"Seed", |
|
min_value=-1, |
|
max_value=(1 << 53) - 1, |
|
value=-1, |
|
disabled=st.session_state.running, |
|
) |
|
|
|
if param == "negative_prompt": |
|
parameters[param] = st.sidebar.text_area( |
|
"Negative Prompt", |
|
value=model_config.negative_prompt, |
|
disabled=st.session_state.running, |
|
) |
|
|
|
if param == "width": |
|
parameters[param] = st.sidebar.slider( |
|
"Width", |
|
step=64, |
|
value=model_config.width, |
|
min_value=model_config.width_range[0], |
|
max_value=model_config.width_range[1], |
|
disabled=st.session_state.running, |
|
) |
|
|
|
if param == "height": |
|
parameters[param] = st.sidebar.slider( |
|
"Height", |
|
step=64, |
|
value=model_config.height, |
|
min_value=model_config.height_range[0], |
|
max_value=model_config.height_range[1], |
|
disabled=st.session_state.running, |
|
) |
|
|
|
if param == "image_size": |
|
parameters[param] = st.sidebar.select_slider( |
|
"Image Size", |
|
options=model_config.image_sizes, |
|
value=model_config.image_size, |
|
disabled=st.session_state.running, |
|
) |
|
|
|
if param == "aspect_ratio": |
|
parameters[param] = st.sidebar.select_slider( |
|
"Aspect Ratio", |
|
options=model_config.aspect_ratios, |
|
value=model_config.aspect_ratio, |
|
disabled=st.session_state.running, |
|
) |
|
|
|
if param in ["guidance_scale", "guidance"]: |
|
parameters[param] = st.sidebar.slider( |
|
"Guidance Scale", |
|
model_config.guidance_scale_range[0], |
|
model_config.guidance_scale_range[1], |
|
model_config.guidance_scale, |
|
0.1, |
|
disabled=st.session_state.running, |
|
) |
|
|
|
if param in ["num_inference_steps", "steps"]: |
|
parameters[param] = st.sidebar.slider( |
|
"Inference Steps", |
|
model_config.num_inference_steps_range[0], |
|
model_config.num_inference_steps_range[1], |
|
model_config.num_inference_steps, |
|
1, |
|
disabled=st.session_state.running, |
|
) |
|
|
|
if param == "strength": |
|
parameters[param] = st.sidebar.slider( |
|
"Strength", |
|
model_config.strength_range[0], |
|
model_config.strength_range[1], |
|
model_config.strength, |
|
0.05, |
|
disabled=st.session_state.running, |
|
) |
|
|
|
if param in ["expand_prompt", "prompt_expansion"]: |
|
parameters[param] = st.sidebar.checkbox( |
|
"Prompt Expansion", |
|
value=False, |
|
disabled=st.session_state.running, |
|
) |
|
|
|
if param == "prompt_upsampling": |
|
parameters[param] = st.sidebar.checkbox( |
|
"Prompt Upsampling", |
|
value=False, |
|
disabled=st.session_state.running, |
|
) |
|
|
|
if param == "image_url": |
|
image_file = st.sidebar.file_uploader( |
|
"Image", |
|
type=["bmp", "gif", "jpg", "jpeg", "png", "webp"], |
|
accept_multiple_files=False, |
|
disabled=st.session_state.running, |
|
) |
|
if image_file: |
|
parameters[param] = base64_encode_image_file(image_file) |
|
|
|
|
|
for message in st.session_state.txt2img_messages: |
|
role = message["role"] |
|
with st.chat_message(role): |
|
image_container = st.empty() |
|
|
|
with image_container.container(): |
|
if role == "user": |
|
with st.expander(message["content"]): |
|
|
|
st.html(""" |
|
<style> |
|
div[data-testid="stMarkdownContainer"] p:not(:last-of-type) { margin-bottom: 0 } |
|
</style> |
|
""") |
|
filtered_parameters = [ |
|
f"`{k}`: {v}" |
|
for k, v in message["parameters"].items() |
|
if k not in config.hidden_parameters |
|
] |
|
st.markdown(f"`model`: {message['model']}\n\n" + "\n\n".join(filtered_parameters)) |
|
|
|
if role == "assistant": |
|
|
|
st.html(""" |
|
<style> |
|
div[data-testid="stImage"]:has(img[style*="max-width: 100%"]) { |
|
height: auto; |
|
max-width: 512px; |
|
} |
|
div[data-testid="stImage"] img[style*="max-width: 100%"] { |
|
border-radius: 8px; |
|
} |
|
</style> |
|
""") |
|
st.write(message["content"]) |
|
|
|
|
|
if st.session_state.txt2img_messages: |
|
button_container = st.empty() |
|
with button_container.container(): |
|
|
|
st.html(""" |
|
<style> |
|
div[data-testid="column"] { |
|
width: fit-content; |
|
min-width: 0; |
|
flex: none; |
|
} |
|
</style> |
|
""") |
|
|
|
col1, col2 = st.columns(2) |
|
with col1: |
|
if ( |
|
st.button("β", help="Delete last generation", disabled=st.session_state.running) |
|
and len(st.session_state.txt2img_messages) >= 2 |
|
): |
|
st.session_state.txt2img_messages.pop() |
|
st.session_state.txt2img_messages.pop() |
|
st.rerun() |
|
|
|
with col2: |
|
if st.button("ποΈ", help="Clear all generations", disabled=st.session_state.running): |
|
st.session_state.txt2img_messages = [] |
|
st.session_state.txt2img_seed = 0 |
|
st.rerun() |
|
else: |
|
button_container = None |
|
|
|
|
|
|
|
if prompt := st.chat_input( |
|
"What do you want to see?", |
|
on_submit=lambda: setattr(st.session_state, "running", True), |
|
): |
|
if "seed" in parameters and parameters["seed"] >= 0: |
|
st.session_state.txt2img_seed = parameters["seed"] |
|
else: |
|
st.session_state.txt2img_seed = int(datetime.now().timestamp() * 1e6) % (1 << 53) |
|
if "seed" in parameters: |
|
parameters["seed"] = st.session_state.txt2img_seed |
|
|
|
if button_container: |
|
button_container.empty() |
|
|
|
with st.chat_message("user"): |
|
st.markdown(prompt) |
|
|
|
with st.chat_message("assistant"): |
|
with st.spinner("Running..."): |
|
if model_config.kwargs: |
|
parameters.update(model_config.kwargs) |
|
session_key = f"api_key_{service}" |
|
api_key = st.session_state[session_key] or image_services[service].api_key |
|
image = txt2img_generate(api_key, service, model, prompt, parameters) |
|
st.session_state.running = False |
|
|
|
st.session_state.txt2img_messages.append( |
|
{"role": "user", "content": prompt, "parameters": parameters, "model": model_config.name} |
|
) |
|
st.session_state.txt2img_messages.append({"role": "assistant", "content": image}) |
|
st.rerun() |
|
|