# Genmoji / app.py
import gradio as gr
import torch
from diffusers import FluxPipeline
import huggingface_hub
from huggingface_hub import InferenceClient
import os
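# Log in to the Hugging Face Hub (FLUX.1-dev is a gated repo; this assumes a
# valid token is set in the HUGGINGFACE_API_TOKEN environment variable)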
huggingface_hub.login(token=os.getenv("HUGGINGFACE_API_TOKEN"))
# Initialize the Flux pipeline
def initialize_flux_pipeline():
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    )
    # Offload model components to CPU between uses to reduce GPU memory pressure
    pipe.enable_model_cpu_offload()
    pipe.load_lora_weights("EvanZhouDev/open-genmoji", weight_name="flux-dev.safetensors")
    return pipe

flux_pipeline = initialize_flux_pipeline()
# Initialize the language model client
llm_client = InferenceClient("Qwen/Qwen2.5-72B-Instruct", token=os.getenv("HUGGINGFACE_API_TOKEN"))
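# The Qwen model is queried through the hosted Inference API, reusing the same token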
# Function to refine the prompt
def refine_prompt(original_prompt):
    messages = [
        {
            "role": "system",
            "content": (
                "You are helping create a prompt for an Emoji generation image model. An emoji must be easily "
                "interpreted when small, so details must be exaggerated to be clear. Your goal is to use descriptions "
                "to achieve this.\n\nYou will receive a user description, and you must rephrase it to consist of "
                "short phrases separated by periods, adding detail to everything the user provides.\n\nAlso describe "
                "the color of all parts or components of the emoji. Unless otherwise specified by the user, do not "
                "describe people. Do not describe the background of the image. Your output should be in the format:\n\n"
                "```emoji of {description}. {addon phrases}. 3D lighting. no cast shadows.```\n\nThe description "
                "should be one sentence giving your interpretation of the emoji. Then, you may choose to add addon phrases."
                " You must use the following in the given scenarios:\n\n- \"cute.\": If generating anything that's not "
                "an object, and also not a human\n- \"enlarged head in cartoon style.\": ONLY animals\n- \"head is "
                "turned towards viewer.\": ONLY humans or animals\n- \"detailed texture.\": ONLY objects\n\nFurther "
                "addon phrases may be added to ensure the clarity of the emoji."
            ),
        },
        {"role": "user", "content": original_prompt},
    ]
    completion = llm_client.chat_completion(messages, max_tokens=100)
    refined = completion.choices[0].message.content.strip()
    return refined
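# Illustrative example (hypothetical output, following the format the system prompt requests):
#   refine_prompt("happy dog") ->
#   "emoji of a smiling golden dog. enlarged head in cartoon style.
#    head is turned towards viewer. cute. 3D lighting. no cast shadows."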
# Generate an image for a prompt; returns the refined prompt and the image
def process(prompt, guidance_scale, num_inference_steps, height, width, seed):
    print(f"Original Prompt: {prompt}")
    # Refine the prompt once and reuse it for generation
    try:
        refined_prompt = refine_prompt(prompt)
        print(f"Refined Prompt: {refined_prompt}")
    except Exception as e:
        raise gr.Error(f"Error refining prompt: {e}")
    # Seed the random generator; fall back to CPU when CUDA is unavailable
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device=device).manual_seed(int(seed))
    try:
        # Generate the image
        output = flux_pipeline(
            prompt=refined_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            height=height,
            width=width,
            generator=generator,
        )
        image = output.images[0]
        return refined_prompt, image
    except Exception as e:
        raise gr.Error(f"Error generating image: {e}")
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Flux Text-to-Image Generator with Prompt Refinement")
    # User inputs
    with gr.Row():
        prompt_input = gr.Textbox(label="Enter a Prompt", placeholder="Describe your image")
        guidance_scale_input = gr.Slider(
            label="Guidance Scale", minimum=1.0, maximum=10.0, value=3.5, step=0.1
        )
    with gr.Row():
        num_inference_steps_input = gr.Slider(
            label="Inference Steps", minimum=1, maximum=100, value=50, step=1
        )
        seed_input = gr.Number(label="Seed", value=42, precision=0)
    with gr.Row():
        height_input = gr.Slider(label="Height", minimum=256, maximum=2048, value=768, step=64)
        width_input = gr.Slider(label="Width", minimum=256, maximum=2048, value=1360, step=64)
    # Output components
    refined_prompt_output = gr.Textbox(label="Refined Prompt", interactive=False)
    image_output = gr.Image(label="Generated Image")
    # Button to generate the image
    generate_button = gr.Button("Generate Image")
    # Define button click behavior: process() returns the refined prompt and the image,
    # so the prompt is only refined once per click
    generate_button.click(
        fn=process,
        inputs=[
            prompt_input,
            guidance_scale_input,
            num_inference_steps_input,
            height_input,
            width_input,
            seed_input,
        ],
        outputs=[refined_prompt_output, image_output],
    )
# Launch the app
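# show_error=True surfaces the gr.Error messages raised in process() in the browser UI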
if __name__ == "__main__":
    demo.launch(show_error=True)