wybxc's picture
disable fp16
f9e11a1 unverified
import streamlit as st
import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
def void(*args: object, **kwargs: object) -> None:
    """Accept any arguments, ignore them, and return ``None``.

    Handy for wrapping an expression whose value should be discarded, e.g.
    inside a ``lambda`` that is expected to produce no result.
    """
    return None
# Page heading.
st.title("AI 元火娘")

# Model picker in the sidebar: three published checkpoints plus a "<Custom>"
# entry that replaces the selection with a free-form path / repo id.
with st.sidebar:
    model = st.selectbox(
        "Model Name",
        [
            "wybxc/yanhuo-v1-dreambooth",
            "wybxc/yanyuan-v1-dreambooth",
            "wybxc/yuanhuo-v1-dreambooth",
            "<Custom>",
        ],
    )
    if model == "<Custom>":
        # Surrounding whitespace is stripped, so an empty box yields "".
        model = st.text_input("Model Path", "").strip()
# Cache the loaded pipeline in the Streamlit session state so it survives
# script reruns. It is (re)loaded only when the selected model differs from
# the cached one or when nothing has been loaded yet.
if "model" not in st.session_state:
    st.session_state.model = model
if "pipeline" not in st.session_state:
    st.session_state.pipeline = None

if model != st.session_state.model or st.session_state.pipeline is None:
    # Stays None when no model name was given or when loading fails, which
    # keeps the rest of the page usable without a pipeline.
    pipeline = None
    if model:
        with st.spinner("Loading Model..."):
            try:
                pipeline = StableDiffusionPipeline.from_pretrained(model)
            except (OSError, ValueError) as err:
                # NOTE(review): a missing or malformed path / repo id is
                # expected to surface as OSError or ValueError - confirm
                # against the installed diffusers / huggingface_hub versions.
                # Report it on the page instead of crashing the whole app.
                pipeline = None
                st.error(f"Failed to load model {model!r}: {err}")
            else:
                # Explicit check rather than `assert`, which is stripped
                # when Python runs with -O.
                if not isinstance(pipeline, StableDiffusionPipeline):
                    raise TypeError(
                        "expected a StableDiffusionPipeline, got "
                        f"{type(pipeline).__name__}"
                    )
                if torch.cuda.is_available():
                    pipeline = pipeline.to("cuda")
                # Update the cache only after a fully successful load.
                st.session_state.model = model
                st.session_state.pipeline = pipeline
else:
    pipeline = st.session_state.pipeline
    if not isinstance(pipeline, StableDiffusionPipeline):
        raise TypeError(
            "cached pipeline has unexpected type "
            f"{type(pipeline).__name__}"
        )
# Default text shown in the two prompt boxes; the user can edit either one.
default_prompt = (
    "(yanhuo), 1girl, masterpiece, best quality, "
    "white hair, ahoge, snowy street, [smile], dynamic angle, full body, "
    "[blue eyes], flat chest, cinematic light"
)
default_negative_prompt = (
    "lowres, bad anatomy, bad hands, "
    "text, error, missing fingers, extra digit, fewer digits, cropped, "
    "worst quality, low quality, normal quality, jpeg artifacts, signature, "
    "watermark, username, blurry"
)

prompt = st.text_area("Prompt", value=default_prompt)
negative_prompt = st.text_area("Negative Prompt", value=default_negative_prompt)

# Image size and sampling step count are chosen in the sidebar.
with st.sidebar:
    height = st.slider("Height", min_value=256, max_value=1024, value=512, step=64)
    width = st.slider("Width", min_value=256, max_value=1024, value=512, step=64)
    steps = st.slider("Steps", min_value=1, max_value=100, value=20, step=1)
# Run inference only when a pipeline is available and the user pressed the
# button; the button is not rendered at all while no pipeline is loaded.
if pipeline is not None and st.button("Generate"):
    progress = st.progress(0)
    result = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_inference_steps=steps,
        # The first callback argument is used as the current step index.
        # The fraction is clamped because st.progress rejects floats above
        # 1.0. NOTE(review): some schedulers may report an index >= steps -
        # confirm against the installed diffusers version.
        callback=lambda s, *_: void(progress.progress(min(s / steps, 1.0))),
    )
    # Explicit check rather than `assert`, which is stripped under -O.
    if not isinstance(result, StableDiffusionPipelineOutput):
        raise TypeError(
            "expected a StableDiffusionPipelineOutput, got "
            f"{type(result).__name__}"
        )
    image = result.images[0]
    # Make sure the bar ends full regardless of the last callback value.
    progress.progress(1.0)
    st.image(image)