Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
import numpy as np | |
import random | |
import spaces | |
import torch | |
from diffusers import DiffusionPipeline | |
from PIL import Image | |
import io | |
# Runtime configuration: bfloat16 weights, on CUDA when torch reports a GPU.
dtype = torch.bfloat16
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Load FLUX.1-schnell once at import time and move it to `device`.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
pipe = pipe.to(device)

# Upper bounds for random seeds and for the image-size sliders.
MAX_SEED = np.iinfo(np.int32).max  # 2**31 - 1
MAX_IMAGE_SIZE = 2048
import numpy as np | |
from collections import Counter | |
def get_prominent_colors(image, num_colors=5, edge_percentile=90):
    """
    Return the most common colors found on the strongest edges of an image.

    An edge mask is built from the gradient magnitude of the per-pixel channel
    mean. Only pixels whose magnitude is strictly above the ``edge_percentile``
    percentile of all magnitudes are counted.

    Args:
        image: A PIL image or array-like of shape (H, W), (H, W, C<3) or
            (H, W, C>=3). Grayscale input is expanded to three identical
            channels. Channels beyond the first three (e.g. alpha) are ignored.
        num_colors: Maximum number of colors to return.
        edge_percentile: Percentile (0-100) of the gradient magnitude used as
            the edge threshold. Higher values keep fewer, stronger edges.

    Returns:
        A list of ``((r, g, b), count)`` pairs, most frequent first. Ties keep
        first-seen order. The list is empty when no pixel exceeds the
        threshold, e.g. for a flat single-color image.
    """
    pixels = np.asarray(image)
    if pixels.ndim == 3 and pixels.shape[2] < 3:
        # "L"/"LA"-style input: keep the luminance channel and drop alpha.
        pixels = pixels[..., 0]
    if pixels.ndim == 2:
        # Grayscale: replicate the channel so every color is an (r, g, b) triple.
        pixels = np.repeat(pixels[..., np.newaxis], 3, axis=2)
    else:
        # Ignore alpha/extra channels so they affect neither the edge
        # detection nor the returned color tuples.
        pixels = pixels[..., :3]

    gray = pixels.mean(axis=2)
    if min(gray.shape) < 2:
        # np.gradient needs at least two samples along each axis. For a
        # single-row/column image, fall back to counting every pixel.
        candidates = pixels.reshape(-1, 3)
    else:
        # A single np.gradient call yields both derivatives: (d/dy, d/dx).
        grad_y, grad_x = np.gradient(gray)
        magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)
        threshold = np.percentile(magnitude, edge_percentile)
        candidates = pixels[magnitude > threshold]

    # tolist() yields plain Python ints, which are much cheaper to hash and
    # count than per-element numpy scalars.
    counts = Counter(map(tuple, candidates.tolist()))
    return counts.most_common(num_colors)
def create_tshirt_preview(design_image, tshirt_color="white", template_path="image.jpeg",
                          design_scale=0.35, top_offset=0.2):
    """
    Paste a design onto the t-shirt template image.

    Args:
        design_image: The design, as a PIL image or an array accepted by
            ``Image.fromarray``.
        tshirt_color: Name of the selected shirt color.
            NOTE(review): currently unused; the template is not recolored.
        template_path: Path of the template photo to draw on.
        design_scale: Design width as a fraction of the template width. The
            height follows the design's aspect ratio.
        top_offset: Top edge of the design as a fraction of the template
            height. The design is centered horizontally.

    Returns:
        A new RGB PIL image: the template with the design pasted onto it.
    """
    # Open inside ``with`` so the file handle is released. convert() returns an
    # independent, fully loaded copy with a predictable mode.
    with Image.open(template_path) as template:
        tshirt = template.convert("RGB")
    tshirt_width, tshirt_height = tshirt.size

    if not isinstance(design_image, Image.Image):
        design_image = Image.fromarray(design_image)

    # Scale the design to a fraction of the shirt width, keeping its aspect
    # ratio. Clamp to 1px so resize never receives a zero dimension.
    design_width = max(1, int(tshirt_width * design_scale))
    design_height = max(1, int(design_width * design_image.size[1] / design_image.size[0]))
    design_image = design_image.resize((design_width, design_height), Image.Resampling.LANCZOS)

    x = (tshirt_width - design_width) // 2
    y = int(tshirt_height * top_offset)

    # Use the alpha channel as the paste mask so transparent areas of the
    # design leave the shirt visible.
    mask = design_image.getchannel("A") if design_image.mode == "RGBA" else None
    tshirt.paste(design_image, (x, y), mask)
    return tshirt
# Extra prompt keywords appended for each named style preset.
_STYLE_TERMS = {
    "minimal": ("simple geometric shapes", "clean lines", "minimalist illustration"),
    "vintage": ("distressed effect", "retro typography", "vintage illustration"),
    "artistic": ("hand-drawn style", "watercolor effect", "artistic illustration"),
    "geometric": ("abstract shapes", "geometric patterns", "modern design"),
    "typography": ("bold typography", "creative lettering", "text-based design"),
    "realistic": ("realistic", "cinematic", "photograph"),
}

# Keywords appended to every prompt, regardless of style.
_BASE_TERMS = (
    "create t-shirt design",
    "with centered composition",
    "high quality",
    "professional design",
    "clear background",
)


def enhance_prompt_for_tshirt(prompt, style=None):
    """
    Return *prompt* followed by comma-separated t-shirt design keywords.

    The generic base terms are always appended. When *style* is a known preset
    name, that preset's terms follow them. A falsy or unknown style adds nothing.
    """
    style_extras = _STYLE_TERMS.get(style, ()) if style else ()
    return ", ".join((f"{prompt}", *_BASE_TERMS, *style_extras))
@spaces.GPU
def infer(prompt, style=None, tshirt_color="white", seed=42, randomize_seed=False,
          width=1024, height=1024, num_inference_steps=4,
          progress=gr.Progress(track_tqdm=True)):
    """
    Generate a t-shirt design from a text prompt and render it on the shirt template.

    ``@spaces.GPU`` requests a ZeroGPU device for the duration of each call.
    Without it, a ZeroGPU Space has no GPU attached when the pipeline runs.

    Args:
        prompt: Free-text description of the design.
        style: Optional style preset name passed to ``enhance_prompt_for_tshirt``.
            An empty string or ``None`` adds no style terms.
        tshirt_color: Shirt color name, forwarded to ``create_tshirt_preview``.
        seed: Seed for the torch generator. Replaced when ``randomize_seed`` is true.
        randomize_seed: If true, draw a fresh seed uniformly from ``[0, MAX_SEED]``.
        width, height: Requested output size in pixels.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        progress: Gradio progress tracker. ``track_tqdm=True`` lets Gradio
            report tqdm progress bars created during the call.

    Returns:
        ``(design_image, tshirt_preview, seed)``, where ``seed`` is the value
        actually used, so the UI can display it.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may deliver floats; manual_seed and the pipeline expect ints.
    seed = int(seed)
    width, height = int(width), int(height)
    num_inference_steps = int(num_inference_steps)

    enhanced_prompt = enhance_prompt_for_tshirt(prompt, style)
    generator = torch.Generator().manual_seed(seed)

    # Generate the design.
    design_image = pipe(
        prompt=enhanced_prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=generator,
        guidance_scale=0.0,  # no classifier-free guidance for FLUX.1-schnell
    ).images[0]

    tshirt_preview = create_tshirt_preview(design_image, tshirt_color)
    return design_image, tshirt_preview, seed
# Shirt color choices mapped to hex codes. Insertion order is the order
# in which the names appear in the T-Shirt Color dropdown.
TSHIRT_COLORS = dict(
    White="#FFFFFF",
    Black="#000000",
    Navy="#000080",
    Gray="#808080",
)
# Example rows for gr.Examples: [design description, style preset, shirt color].
# Styles should match the Style dropdown choices, and colors should be
# TSHIRT_COLORS keys.
examples = [
    ["Cool geometric mountain landscape", "minimal", "White"],
    ["Vintage motorcycle with flames", "vintage", "Black"],
    ["flamingo in scenic forest", "realistic", "White"],
    ["Adventure Starts typography", "typography", "White"],
]
# Style presets offered in the Style dropdown. Keep these in sync with the
# style names handled by enhance_prompt_for_tshirt.
styles = ["minimal", "vintage", "artistic", "geometric", "typography", "realistic"]
# Page stylesheet passed to gr.Blocks. It provides a centered container,
# title/subtitle typography, a bordered prompt box, and a two-column grid
# for the results row.
css = """
#col-container {
margin: 0 auto;
max-width: 1200px !important;
padding: 20px;
}
.main-title {
text-align: center;
color: #2d3748;
margin-bottom: 1rem;
font-family: 'Poppins', sans-serif;
}
.subtitle {
text-align: center;
color: #4a5568;
margin-bottom: 2rem;
font-family: 'Inter', sans-serif;
font-size: 0.95rem;
line-height: 1.5;
}
.design-input {
border: 2px solid #e2e8f0;
border-radius: 10px;
padding: 12px !important;
margin-bottom: 1rem !important;
font-size: 1rem;
transition: all 0.3s ease;
}
.results-row {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 20px;
margin-top: 20px;
}
"""
# Assemble the Gradio UI.
#   - Top row: prompt, style preset, shirt color and the Generate button.
#   - Next row: the raw design beside its t-shirt preview.
#   - Then advanced settings and cached examples.
# Clicking Generate or pressing Enter in the prompt box both call `infer`.
# NOTE(review): the T-Shirt Color choice is forwarded to create_tshirt_preview,
# which does not currently use it.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            "\n# 👕Deradh's T-Shirt Design Generator\n",
            elem_classes=["main-title"],
        )
        gr.Markdown(
            "\n"
            "Create unique t-shirt designs using Deradh's AI.\n"
            "Describe your design idea and select a style to generate professional-quality artwork\n"
            "perfect for custom t-shirts.\n",
            elem_classes=["subtitle"],
        )

        # Input controls.
        with gr.Row():
            with gr.Column(scale=2):
                prompt = gr.Text(
                    placeholder="Describe your t-shirt design idea",
                    label="Design Description",
                    show_label=False,
                    container=False,
                    max_lines=1,
                    elem_classes=["design-input"],
                )
            with gr.Column(scale=1):
                # The leading "" entry means "no style preset".
                style = gr.Dropdown(
                    label="Style",
                    choices=["", *styles],
                    value="",
                    container=False,
                )
            with gr.Column(scale=1):
                tshirt_color = gr.Dropdown(
                    label="T-Shirt Color",
                    choices=list(TSHIRT_COLORS),
                    value="White",
                    container=False,
                )
            run_button = gr.Button(
                "✨ Generate",
                elem_classes=["generate-button"],
                scale=0,
            )

        # Outputs: the generated design next to the t-shirt mock-up.
        with gr.Row(elem_classes=["results-row"]):
            result = gr.Image(
                label="Generated Design",
                elem_classes=["result-image"],
                show_label=True,
            )
            preview = gr.Image(
                label="T-Shirt Preview",
                elem_classes=["preview-image"],
                show_label=True,
            )

        with gr.Accordion("🔧 Advanced Settings", open=False):
            with gr.Group():
                seed = gr.Slider(
                    label="Design Seed",
                    value=0,
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                )
                randomize_seed = gr.Checkbox(value=True, label="Randomize Design")
            with gr.Row():
                width = gr.Slider(
                    label="Width", value=1024, minimum=256, maximum=MAX_IMAGE_SIZE, step=32,
                )
                height = gr.Slider(
                    label="Height", value=1024, minimum=256, maximum=MAX_IMAGE_SIZE, step=32,
                )
            num_inference_steps = gr.Slider(
                label="Generation Quality (Steps)",
                value=4,
                minimum=1,
                maximum=50,
                step=1,
            )

        # Examples fill only prompt/style/color; infer's defaults supply the rest.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt, style, tshirt_color],
            outputs=[result, preview, seed],
            cache_examples=True,
        )

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt, style, tshirt_color, seed, randomize_seed,
            width, height, num_inference_steps,
        ],
        outputs=[result, preview, seed],
    )

demo.launch()