# Author: Gopalag
# Last commit: "Update app.py" (97a01d0, verified)
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline
from PIL import Image
import io
# Precision for the diffusion pipeline, and the device it runs on (GPU when
# CUDA is available, otherwise CPU).
dtype = torch.bfloat16
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load FLUX.1-schnell once at import time, then move it to the chosen device.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=dtype,
)
pipe = pipe.to(device)

# Largest seed value: the maximum signed 32-bit integer. Used as the seed
# slider's maximum and as the upper bound when a random seed is drawn.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound, in pixels, for the width/height sliders.
MAX_IMAGE_SIZE = 2048
import numpy as np
from collections import Counter
def get_prominent_colors(image, num_colors=5, edge_percentile=90):
    """
    Return the most frequent colours among an image's strongest-edge pixels.

    Edges are approximated by the gradient magnitude of the per-pixel channel
    mean. Pixels whose magnitude is at or above the ``edge_percentile``-th
    percentile are treated as edge pixels, and their colours are counted.

    Args:
        image: A PIL image or anything ``np.asarray`` turns into an
            (H, W[, C]) array. 2-D input (e.g. a mode "L" image) is expanded
            to three equal channels so results are always colour tuples.
        num_colors: Maximum number of colours to return.
        edge_percentile: Percentile (0-100) of gradient magnitude used as the
            edge threshold. Lower values include more pixels.

    Returns:
        A list of ``(color_tuple, count)`` pairs, most frequent first. Ties
        keep first-seen order, as with ``Counter.most_common``. Each
        ``color_tuple`` holds one plain Python number per channel. Returns
        ``[]`` for images smaller than 2x2, where no gradient can be computed.
    """
    img_array = np.asarray(image)
    if img_array.ndim == 2:
        # Treat single-channel input as RGB with equal channels.
        img_array = np.stack([img_array] * 3, axis=-1)

    height, width = img_array.shape[:2]
    if height < 2 or width < 2:
        # np.gradient needs at least two samples along each axis.
        return []

    # Compute the gradient once. np.gradient returns (d/d_row, d/d_col).
    grad_y, grad_x = np.gradient(img_array.mean(axis=2))
    gradient_magnitude = np.hypot(grad_x, grad_y)

    # Use ">=" rather than ">". In flat-colour artwork many pixels share the
    # maximum gradient, so the percentile equals that maximum, and a strict
    # comparison would select no pixels at all.
    edge_threshold = np.percentile(gradient_magnitude, edge_percentile)
    edge_mask = gradient_magnitude >= edge_threshold

    # tolist() converts the (N, channels) selection to Python lists in C,
    # which is much faster than building tuples of NumPy scalars per pixel.
    edge_colors = img_array[edge_mask].tolist()
    color_counts = Counter(map(tuple, edge_colors))
    return color_counts.most_common(num_colors)
def create_tshirt_preview(design_image, tshirt_color="white", *,
                          template_path="image.jpeg", design_scale=0.35,
                          top_offset=0.2):
    """
    Paste a generated design onto the t-shirt template image.

    The design is resized to ``design_scale`` of the template's width, keeping
    its aspect ratio. It is centred horizontally, with its top edge placed
    ``top_offset`` of the template's height from the top. If the design has an
    alpha channel, that channel is used as the paste mask.

    Args:
        design_image: A PIL image, or an array accepted by ``Image.fromarray``.
        tshirt_color: Accepted for interface compatibility; not applied
            (see NOTE below).
        template_path: Path of the t-shirt template image.
        design_scale: Design width as a fraction of the template width.
        top_offset: Vertical position of the design's top edge as a fraction
            of the template height.

    Returns:
        A new RGB PIL image: the template with the design pasted on it.
    """
    # NOTE(review): tshirt_color is not used; the shirt is rendered exactly as
    # it appears in the template image, whatever colour is selected.

    # Open the template in a context manager so the file handle is closed.
    # convert() loads the pixels and returns an independent RGB image.
    with Image.open(template_path) as template:
        tshirt = template.convert("RGB")
    tshirt_width, tshirt_height = tshirt.size

    if not isinstance(design_image, Image.Image):
        design_image = Image.fromarray(design_image)

    # Scale the design relative to the shirt. Clamp to 1 px so a tiny
    # template cannot produce an invalid zero-size resize.
    design_width = max(1, int(tshirt_width * design_scale))
    design_height = max(
        1, int(design_width * design_image.size[1] / design_image.size[0])
    )
    design_image = design_image.resize(
        (design_width, design_height), Image.Resampling.LANCZOS
    )

    # Centre horizontally; place vertically per top_offset (template-specific).
    x = (tshirt_width - design_width) // 2
    y = int(tshirt_height * top_offset)

    # Use the alpha band (RGBA, LA, ...) as a mask so transparent areas of
    # the design leave the shirt visible.
    if "A" in design_image.getbands():
        mask = design_image.getchannel("A")
    else:
        mask = None

    tshirt.paste(design_image, (x, y), mask)
    return tshirt
def enhance_prompt_for_tshirt(prompt, style=None):
    """
    Expand a user prompt with wording that steers the model toward t-shirt art.

    A fixed set of generic design terms is always appended. When ``style``
    names a known style, that style's descriptive terms follow. A falsy or
    unrecognised ``style`` adds nothing extra.

    Returns:
        The prompt and all added terms, joined with ", ".
    """
    generic_terms = (
        "create t-shirt design",
        "with centered composition",
        "high quality",
        "professional design",
        "clear background",
    )
    terms_by_style = {
        "minimal": ("simple geometric shapes", "clean lines", "minimalist illustration"),
        "vintage": ("distressed effect", "retro typography", "vintage illustration"),
        "artistic": ("hand-drawn style", "watercolor effect", "artistic illustration"),
        "geometric": ("abstract shapes", "geometric patterns", "modern design"),
        "typography": ("bold typography", "creative lettering", "text-based design"),
        "realistic": ("realistic", "cinematic", "photograph"),
    }

    pieces = [f"{prompt}", *generic_terms]
    extra_terms = terms_by_style.get(style) if style else None
    if extra_terms is not None:
        pieces.extend(extra_terms)
    return ", ".join(pieces)
@spaces.GPU()
def infer(prompt, style=None, tshirt_color="white", seed=42, randomize_seed=False,
          width=1024, height=1024, num_inference_steps=4,
          progress=gr.Progress(track_tqdm=True)):
    """
    Generate a t-shirt design from a text prompt and preview it on a shirt.

    Args:
        prompt: The user's design description.
        style: Optional style name for ``enhance_prompt_for_tshirt``;
            ``None`` or ``""`` adds no style terms.
        tshirt_color: Shirt colour name forwarded to ``create_tshirt_preview``.
        seed: Generator seed. Replaced by a random value in ``[0, MAX_SEED]``
            when ``randomize_seed`` is true.
        randomize_seed: Draw a fresh seed instead of using ``seed``.
        width, height: Output image size in pixels.
        num_inference_steps: Number of denoising steps.
        progress: Not referenced in the body. Presumably present so Gradio
            can mirror the pipeline's tqdm progress (``track_tqdm=True``).

    Returns:
        ``(design_image, tshirt_preview, seed)``. ``seed`` is the int seed
        actually used, so the UI can display it.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio sliders may deliver numbers as floats, but
    # torch.Generator.manual_seed only accepts integers, and the pipeline
    # expects integer sizes and step counts.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    enhanced_prompt = enhance_prompt_for_tshirt(prompt, style)
    # Explicitly seeded so the seed shown in the UI can be reused.
    generator = torch.Generator().manual_seed(seed)

    # Generate the design. guidance_scale=0.0 runs the schnell checkpoint
    # without classifier-free guidance.
    design_image = pipe(
        prompt=enhanced_prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=generator,
        guidance_scale=0.0,
    ).images[0]

    tshirt_preview = create_tshirt_preview(design_image, tshirt_color)
    return design_image, tshirt_preview, seed
# Selectable shirt colours, mapping display name to hex code. The
# "T-Shirt Color" dropdown offers the keys, in this order.
TSHIRT_COLORS = dict(
    White="#FFFFFF",
    Black="#000000",
    Navy="#000080",
    Gray="#808080",
)
# Example rows for gr.Examples, each [design description, style, t-shirt colour].
examples = [
    ["Cool geometric mountain landscape", "minimal", "White"],
    ["Vintage motorcycle with flames", "vintage", "Black"],
    # Fixed typo: "forset" -> "forest". The text is sent to the model as-is.
    ["flamingo in scenic forest", "realistic", "White"],
    ["Adventure Starts typography", "typography", "White"],
]
# Style names offered in the "Style" dropdown, after its empty "no style" entry.
styles = "minimal vintage artistic geometric typography realistic".split()
# Custom CSS passed to gr.Blocks(css=...). The selectors target the
# elem_id="col-container" and elem_classes (main-title, subtitle,
# design-input, results-row) assigned to components in the UI layout.
css = """
#col-container {
margin: 0 auto;
max-width: 1200px !important;
padding: 20px;
}
.main-title {
text-align: center;
color: #2d3748;
margin-bottom: 1rem;
font-family: 'Poppins', sans-serif;
}
.subtitle {
text-align: center;
color: #4a5568;
margin-bottom: 2rem;
font-family: 'Inter', sans-serif;
font-size: 0.95rem;
line-height: 1.5;
}
.design-input {
border: 2px solid #e2e8f0;
border-radius: 10px;
padding: 12px !important;
margin-bottom: 1rem !important;
font-size: 1rem;
transition: all 0.3s ease;
}
.results-row {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 20px;
margin-top: 20px;
}
"""
# Gradio UI. Components are created in the order and nesting below, which
# determines the on-page layout; the app is launched at the end.
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        # Page header and short description.
        gr.Markdown(
            """
# 👕Deradh's T-Shirt Design Generator
""",
            elem_classes=["main-title"]
        )
        gr.Markdown(
            """
Create unique t-shirt designs using Deradh's AI.
Describe your design idea and select a style to generate professional-quality artwork
perfect for custom t-shirts.
""",
            elem_classes=["subtitle"]
        )
        # Input row: description, optional style ("" means no style), shirt
        # colour, and the generate button.
        with gr.Row():
            with gr.Column(scale=2):
                prompt = gr.Text(
                    label="Design Description",
                    show_label=False,
                    max_lines=1,
                    placeholder="Describe your t-shirt design idea",
                    container=False,
                    elem_classes=["design-input"]
                )
            with gr.Column(scale=1):
                style = gr.Dropdown(
                    choices=[""] + styles,
                    value="",
                    label="Style",
                    container=False
                )
            with gr.Column(scale=1):
                tshirt_color = gr.Dropdown(
                    choices=list(TSHIRT_COLORS.keys()),
                    value="White",
                    label="T-Shirt Color",
                    container=False
                )
            run_button = gr.Button(
                "✨ Generate",
                scale=0,
                elem_classes=["generate-button"]
            )
        # Outputs: the raw generated design and the design pasted on the shirt.
        with gr.Row(elem_classes=["results-row"]):
            result = gr.Image(
                label="Generated Design",
                show_label=True,
                elem_classes=["result-image"]
            )
            preview = gr.Image(
                label="T-Shirt Preview",
                show_label=True,
                elem_classes=["preview-image"]
            )
        # Generation parameters, passed to infer() by the gr.on handler below.
        with gr.Accordion("🔧 Advanced Settings", open=False):
            with gr.Group():
                # Also an output of infer(), so it shows the seed actually used.
                seed = gr.Slider(
                    label="Design Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                # When checked, infer() ignores the seed slider and draws a random seed.
                randomize_seed = gr.Checkbox(
                    label="Randomize Design",
                    value=True
                )
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=1024,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=1024,
                    )
                num_inference_steps = gr.Slider(
                    label="Generation Quality (Steps)",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=4,
                )
        # Only prompt, style and colour come from each example row. infer()
        # uses its defaults for the rest (seed=42, randomize_seed=False,
        # 1024x1024, 4 steps). cache_examples=True stores the example outputs
        # so they are not regenerated when an example is clicked.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt, style, tshirt_color],
            outputs=[result, preview, seed],
            cache_examples=True
        )
    # Generate on button click, or when the prompt textbox is submitted (Enter).
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, style, tshirt_color, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, preview, seed]
    )

demo.launch()