import os

import gradio as gr
import torch
from diffusers import FluxPipeline
from huggingface_hub import InferenceClient, login

login(token=os.getenv("HUGGINGFACE_API_TOKEN"))

SYSTEM_PROMPT = (
    "You are helping create a prompt for an Emoji generation image model. An emoji must be easily "
    "interpreted when small, so details must be exaggerated to be clear. Your goal is to use descriptions "
    "to achieve this.\n\nYou will receive a user description, and you must rephrase it to consist of "
    "short phrases separated by periods, adding detail to everything the user provides.\n\nAlso describe "
    "the color of all parts or components of the emoji. Unless otherwise specified by the user, do not "
    "describe people. Do not describe the background of the image. Your output should be in the format:\n\n"
    "```emoji of {description}. {addon phrases}. 3D lighting. no cast shadows.```\n\nThe description "
    "should be one sentence giving your interpretation of the emoji. Then, you may choose to add addon "
    "phrases. You must use the following in the given scenarios:\n\n- \"cute.\": If generating anything that's not "
    "an object, and also not a human\n- \"enlarged head in cartoon style.\": ONLY animals\n- \"head is "
    "turned towards viewer.\": ONLY humans or animals\n- \"detailed texture.\": ONLY objects\n\nFurther "
    "addon phrases may be added to ensure the clarity of the emoji."
)


def initialize_flux_pipeline():
    """Load FLUX.1-dev with the Open Genmoji LoRA, offloading weights to CPU as needed."""
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    )
    pipe.enable_model_cpu_offload()
    pipe.load_lora_weights("EvanZhouDev/open-genmoji", weight_name="flux-dev.safetensors")
    return pipe


flux_pipeline = initialize_flux_pipeline()

# Language model used to rewrite user descriptions into emoji-style prompts.
llm_client = InferenceClient("Qwen/Qwen2.5-72B-Instruct", token=os.getenv("HUGGINGFACE_API_TOKEN"))


def refine_prompt(original_prompt):
    """Ask the LLM to rewrite the user's description into the emoji prompt format."""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": original_prompt},
    ]
    completion = llm_client.chat_completion(messages, max_tokens=100)
    return completion.choices[0].message.content.strip()


def process(prompt, guidance_scale, num_inference_steps, height, width, seed):
    """Refine the prompt, then generate an image. Returns (refined_prompt, image)."""
    print(f"Original Prompt: {prompt}")

    try:
        refined_prompt = refine_prompt(prompt)
        print(f"Refined Prompt: {refined_prompt}")
    except Exception as e:
        # gr.Error surfaces the message in the UI instead of feeding a string
        # into the Image output component.
        raise gr.Error(f"Error refining prompt: {e}")

    # Seed the generator on the device the pipeline actually runs on.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device=device).manual_seed(int(seed))

    try:
        output = flux_pipeline(
            prompt=refined_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            height=height,
            width=width,
            generator=generator,
        )
    except Exception as e:
        raise gr.Error(f"Error generating image: {e}")

    return refined_prompt, output.images[0]


# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Flux Text-to-Image Generator with Prompt Refinement")

    # User inputs
    with gr.Row():
        prompt_input = gr.Textbox(label="Enter a Prompt", placeholder="Describe your image")
        guidance_scale_input = gr.Slider(
            label="Guidance Scale", minimum=1.0, maximum=10.0, value=3.5, step=0.1
        )
    with gr.Row():
        num_inference_steps_input = gr.Slider(
            label="Inference Steps", minimum=1, maximum=100, value=50, step=1
        )
        seed_input = gr.Number(label="Seed", value=42, precision=0)
    with gr.Row():
        # FLUX requires dimensions divisible by 16; step=16 also keeps the
        # default width of 1360 reachable (it is not a multiple of 64).
        height_input = gr.Slider(label="Height", minimum=256, maximum=2048, value=768, step=16)
        width_input = gr.Slider(label="Width", minimum=256, maximum=2048, value=1360, step=16)

    # Output components
    refined_prompt_output = gr.Textbox(label="Refined Prompt", interactive=False)
    image_output = gr.Image(label="Generated Image")

    generate_button = gr.Button("Generate Image")

    # A single call to `process` refines the prompt once and reuses the result
    # for generation, rather than refining twice via a wrapper lambda.
    generate_button.click(
        fn=process,
        inputs=[
            prompt_input,
            guidance_scale_input,
            num_inference_steps_input,
            height_input,
            width_input,
            seed_input,
        ],
        outputs=[refined_prompt_output, image_output],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(show_error=True)
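# Example of headless usage, bypassing the Gradio UI (a minimal sketch; the
# prompt and filename are hypothetical, and this assumes the model weights are
# available locally and HUGGINGFACE_API_TOKEN is set):
#
#     refined, image = process("a smiling avocado", 3.5, 50, 768, 768, 42)
#     print(refined)
#     image.save("emoji.png")  # `image` is a PIL.Image returned by the pipeline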