Safetensors
aredden's picture
Small fixes & clean up
289aa1f
raw
history blame
4.28 kB
import torch
from flux_pipeline import FluxPipeline
import gradio as gr # type: ignore
from PIL import Image
def _parse_seed(seed):
    """Convert the raw value of the seed field into a seed for the generator.

    Args:
        seed: Value received from the "Seed" textbox. Usually a string, but
            ``None`` and plain numbers are accepted as well.

    Returns:
        ``None`` when a random seed is requested (the value is ``None``,
        blank, or ``-1``); otherwise the seed as an ``int``.

    Raises:
        ValueError: If the value cannot be interpreted as an integer.
    """
    if seed is None:
        return None
    if isinstance(seed, str):
        seed = seed.strip()
        if not seed:
            # An empty textbox is treated like -1: let the generator choose.
            return None
    try:
        value = int(seed)
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"Seed must be an integer (-1 for random), got {seed!r}"
        ) from err
    return None if value == -1 else value


def create_demo(
    config_path: str,
):
    """Build the Gradio UI for a Flux pipeline loaded from ``config_path``.

    Args:
        config_path: Path handed to
            ``FluxPipeline.load_pipeline_from_config_path``.

    Returns:
        The ``gr.Blocks`` application, not yet launched.
    """
    generator = FluxPipeline.load_pipeline_from_config_path(config_path)

    def generate_image(
        prompt,
        width,
        height,
        num_steps,
        guidance,
        seed,
        init_image,
        image2image_strength,
        add_sampling_metadata,
        do_img2img=True,
    ):
        """Run one generation and return values for the three output widgets.

        ``add_sampling_metadata`` is received from the UI but is not forwarded
        to the generator by this function.

        Returns:
            A tuple ``(image, seed_text, warning)``: the opened PIL image, the
            seed reported by the generator as a string, and ``None`` for the
            warning textbox.

        Raises:
            gr.Error: If the seed field does not contain an integer.
        """
        try:
            parsed_seed = _parse_seed(seed)
        except ValueError as err:
            # Surface a readable message in the UI instead of a raw traceback.
            raise gr.Error(str(err)) from err

        # A hidden image component keeps its last value, so an image uploaded
        # earlier must be dropped explicitly once img2img is switched off.
        effective_init_image = init_image if do_img2img else None

        out = generator.generate(
            prompt,
            width,
            height,
            num_steps=num_steps,
            guidance=guidance,
            seed=parsed_seed,
            init_image=effective_init_image,
            strength=image2image_strength,
            silent=False,
            num_images=1,
            return_seed=True,
        )
        # out[0] is handed to PIL (presumably a file-like object holding the
        # encoded image -- confirm against FluxPipeline.generate); out[1] is
        # the seed requested via return_seed=True.
        image_bytes = out[0]
        return Image.open(image_bytes), str(out[1]), None

    # The schnell variant gets fewer default steps and has the img2img and
    # guidance controls made non-interactive.
    is_schnell = generator.config.version == "flux-schnell"

    with gr.Blocks() as demo:
        gr.Markdown(f"# Flux Image Generation Demo - Model: {generator.config.version}")

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    value='a photo of a forest with mist swirling around the tree trunks. The word "FLUX" is painted over it in big, red brush strokes with visible texture',
                )
                do_img2img = gr.Checkbox(
                    label="Image to Image", value=False, interactive=not is_schnell
                )
                init_image = gr.Image(label="Input Image", visible=False)
                image2image_strength = gr.Slider(
                    0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False
                )

                with gr.Accordion("Advanced Options", open=False):
                    width = gr.Slider(128, 8192, 1152, step=16, label="Width")
                    height = gr.Slider(128, 8192, 640, step=16, label="Height")
                    num_steps = gr.Slider(
                        1, 50, 4 if is_schnell else 20, step=1, label="Number of steps"
                    )
                    guidance = gr.Slider(
                        1.0,
                        10.0,
                        3.5,
                        step=0.1,
                        label="Guidance",
                        interactive=not is_schnell,
                    )
                    seed = gr.Textbox(-1, label="Seed (-1 for random)")
                    add_sampling_metadata = gr.Checkbox(
                        label="Add sampling parameters to metadata?", value=True
                    )

                generate_btn = gr.Button("Generate")

            # min_width is an integer number of pixels, not a CSS length string.
            with gr.Column(min_width=960):
                output_image = gr.Image(label="Generated Image")
                seed_output = gr.Number(label="Used Seed")
                warning_text = gr.Textbox(label="Warning", visible=False)
                # download_btn = gr.File(label="Download full-resolution")

        def update_img2img(do_img2img):
            """Show or hide the img2img inputs to follow the checkbox state."""
            return {
                init_image: gr.update(visible=do_img2img),
                image2image_strength: gr.update(visible=do_img2img),
            }

        do_img2img.change(
            update_img2img, do_img2img, [init_image, image2image_strength]
        )

        generate_btn.click(
            fn=generate_image,
            inputs=[
                prompt,
                width,
                height,
                num_steps,
                guidance,
                seed,
                init_image,
                image2image_strength,
                add_sampling_metadata,
                do_img2img,
            ],
            outputs=[output_image, seed_output, warning_text],
        )

    return demo
if __name__ == "__main__":
    # Script entry point: parse the command line, build the UI, and serve it.
    import argparse

    cli = argparse.ArgumentParser(description="Flux")
    # Path of the pipeline configuration passed on to create_demo().
    cli.add_argument(
        "--config",
        type=str,
        default="configs/config-dev.json",
        help="Config file path",
    )
    # Flag forwarded to launch(share=...).
    cli.add_argument(
        "--share",
        action="store_true",
        help="Create a public link to your demo",
    )
    options = cli.parse_args()

    demo = create_demo(options.config)
    demo.launch(share=options.share)