import gradio as gr
import numpy as np
import random
import spaces
import torch
import time
import os
from diffusers import DiffusionPipeline
from custom_pipeline import FLUXPipelineWithIntermediateOutputs
from transformers import pipeline
# Translation model setup (runs on CPU): Korean -> English via Helsinki-NLP/opus-mt-ko-en.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cpu")
# Constants
MAX_SEED = np.iinfo(np.int32).max  # largest signed 32-bit integer; upper bound for random seeds
MAX_IMAGE_SIZE = 2048  # upper bound applied to the requested width/height
DEFAULT_WIDTH = 1024
DEFAULT_HEIGHT = 1024
DEFAULT_INFERENCE_STEPS = 1
GPU_DURATION = 15 # GPU allocation time, reduced (passed to spaces.GPU(duration=...); presumably seconds)
# Model setup
def setup_model():
    """Load the FLUX.1-schnell pipeline in float16 and move it to the CUDA device.

    Returns:
        The ``FLUXPipelineWithIntermediateOutputs`` instance returned by
        ``from_pretrained(...).to("cuda")``.
    """
    return FLUXPipelineWithIntermediateOutputs.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        torch_dtype=torch.float16,
    ).to("cuda")


# Module-level pipeline instance, created once at import time.
pipe = setup_model()
# Menu labels: internal English key -> Korean text displayed in the UI.
labels = {
    "Generated Image": "생성된 이미지",
    "Prompt": "프롬프트",
    "Enhance Image": "이미지 향상",
    "Advanced Options": "고급 설정",
    "Seed": "시드",
    "Randomize Seed": "랜덤 시드",
    "Width": "너비",
    "Height": "높이",
    "Inference Steps": "추론 단계",
    "Inspiration Gallery": "영감 갤러리"
}
def translate_if_korean(text):
    """Translate *text* to English when it contains Hangul; otherwise return it as is.

    A character counts as Hangul when it lies in U+3131..U+3163 (compatibility
    jamo) or U+AC00..U+D7A3 (precomposed syllables). Any exception raised while
    scanning or translating is printed and the input is returned unchanged, so
    this function never raises.
    """
    def _is_hangul(ch):
        return '\u3131' <= ch <= '\u3163' or '\uac00' <= ch <= '\ud7a3'

    try:
        for ch in text:
            if _is_hangul(ch):
                # The first Hangul character is enough to translate the whole text.
                return translator(text)[0]['translation_text']
    except Exception as err:
        print(f"번역 오류: {err}")
    # No Hangul found, or translation failed: fall back to the original text.
    return text
# Image generation function
@spaces.GPU(duration=GPU_DURATION)
def generate_image(prompt, seed=None, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT,
                   randomize_seed=True, num_inference_steps=DEFAULT_INFERENCE_STEPS):
    """Generate images for *prompt*, yielding ``(image, seed, latency_text)`` tuples.

    This is a generator: one tuple is yielded for every image produced by
    ``pipe.generate_images``. ``latency_text`` is the elapsed time since
    generation started, formatted as a Korean message.

    Args:
        prompt: Text prompt; translated to English first if it contains Hangul.
        seed: Integer seed, or ``None``. A non-int value is discarded and a
            random seed is used instead.
        width: Requested width; coerced to ``int`` and clamped to
            ``[256, MAX_IMAGE_SIZE]``.
        height: Requested height; coerced and clamped like ``width``.
        randomize_seed: When true, ignore ``seed`` and draw a new one from
            ``[0, MAX_SEED]``.
        num_inference_steps: Number of inference steps; coerced to ``int``.

    Yields:
        ``(image, seed, latency_text)`` on success. If anything raises, the
        error is printed and a single ``(None, seed, error_text)`` is yielded
        instead of propagating the exception.
    """
    try:
        # Input validation: only an int (or None) is accepted as a seed.
        if not isinstance(seed, (int, type(None))):
            seed = None
            randomize_seed = True
        prompt = translate_if_korean(prompt)
        if seed is None or randomize_seed:
            seed = random.randint(0, MAX_SEED)
        # Size validation: UI widgets may deliver floats, so coerce to int
        # before clamping to the supported range.
        width = min(max(256, int(width)), MAX_IMAGE_SIZE)
        height = min(max(256, int(height)), MAX_IMAGE_SIZE)
        num_inference_steps = int(num_inference_steps)
        generator = torch.Generator().manual_seed(seed)
        start_time = time.time()
        # torch.cuda.amp.autocast() is deprecated; torch.autocast("cuda") is
        # the equivalent device-generic spelling.
        with torch.autocast("cuda"):
            for img in pipe.generate_images(
                prompt=prompt,
                guidance_scale=0,
                num_inference_steps=num_inference_steps,
                width=width,
                height=height,
                generator=generator
            ):
                latency = f"처리 시간: {(time.time()-start_time):.2f} 초"
                # Release cached CUDA memory before handing the image back.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                yield img, seed, latency
    except Exception as e:
        print(f"이미지 생성 오류: {e}")
        yield None, seed, f"오류: {str(e)}"
# Example image generation
def generate_example_image(prompt):
    """Return the first ``(image, seed, latency)`` tuple produced for *prompt*.

    Calls ``generate_image`` with a randomized seed and takes only its first
    yielded result. On any exception the error is printed and
    ``(None, None, error_text)`` is returned instead.
    """
    try:
        frames = generate_image(prompt, randomize_seed=True)
        first_result = next(frames)
    except Exception as err:
        print(f"예제 생성 오류: {err}")
        return None, None, f"오류: {err!s}"
    return first_result
# Example prompts offered in the UI. The first entry is Korean
# ("An anime illustration of Wiener schnitzel"); the rest are English.
examples = [
    "비너 슈니첼의 애니메이션 일러스트레이션",
    "A steampunk owl wearing Victorian-era clothing and reading a mechanical book",
    "A floating island made of books with waterfalls of knowledge cascading down",
    "A bioluminescent forest where mushrooms glow like neon signs in a cyberpunk city",
    "An ancient temple being reclaimed by nature, with robots performing archaeology",
    "A cosmic coffee shop where baristas are constellations serving drinks made of stardust"
]
# Custom CSS handed to gr.Blocks(css=...): hides the page footer.
css = """
footer {
visibility: hidden;
}
"""
def create_snow_effect():
    """Build a ``gr.HTML`` component carrying the falling-snow CSS and JavaScript.

    The CSS defines the ``snowfall`` keyframes and the ``.snowflake`` style;
    the JavaScript appends a snowflake element to ``document.body`` every
    200 ms and removes each one after 5 seconds. Both are embedded in the
    returned markup inside ``<style>`` and ``<script>`` tags.

    Must be called inside a ``gr.Blocks`` context so the component is rendered.

    Returns:
        gr.HTML: component whose value is the combined style/script markup.
    """
    # CSS: fall animation keyframes and the base snowflake style.
    snow_css = """
@keyframes snowfall {
0% {
transform: translateY(-10vh) translateX(0);
opacity: 1;
}
100% {
transform: translateY(100vh) translateX(100px);
opacity: 0.3;
}
}
.snowflake {
position: fixed;
color: white;
font-size: 1.5em;
user-select: none;
z-index: 1000;
pointer-events: none;
animation: snowfall linear infinite;
}
"""
    # JavaScript: periodically spawn snowflakes and clean them up afterwards.
    snow_js = """
function createSnowflake() {
const snowflake = document.createElement('div');
snowflake.innerHTML = '❄';
snowflake.className = 'snowflake';
snowflake.style.left = Math.random() * 100 + 'vw';
snowflake.style.animationDuration = Math.random() * 3 + 2 + 's';
snowflake.style.opacity = Math.random();
document.body.appendChild(snowflake);
setTimeout(() => {
snowflake.remove();
}, 5000);
}
setInterval(createSnowflake, 200);
"""
    # HTML combining the CSS and JavaScript. Previously this template was
    # empty, so neither snippet ever reached the page.
    # NOTE(review): browsers generally do not execute <script> tags injected
    # via innerHTML; if the snow does not animate, move the script to the
    # Blocks `js=`/`head=` parameter instead — confirm against the Gradio
    # version in use.
    snow_html = f"""
<style>
{snow_css}
</style>
<script>
{snow_js}
</script>
"""
    return gr.HTML(snow_html)
# Usage note for create_snow_effect(): call it inside the Blocks context,
# directly below the `with gr.Blocks(...)` line.
# Gradio UI layout
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    create_snow_effect()
    with gr.Column(elem_id="app-container"):
        with gr.Row():
            # Left: generated image. Right: prompt and controls.
            with gr.Column(scale=3):
                result = gr.Image(
                    label=labels["Generated Image"],
                    show_label=False,
                    interactive=False,
                )
            with gr.Column(scale=1):
                prompt = gr.Text(
                    label=labels["Prompt"],
                    placeholder="생성하고 싶은 이미지를 설명해주세요...",
                    lines=3,
                    show_label=False,
                    container=False,
                )
                enhanceBtn = gr.Button(f"🚀 {labels['Enhance Image']}")
                # gr.Column only accepts keyword arguments and has no title,
                # so the section label was never displayed. gr.Accordion takes
                # the label and is kept open so every control stays visible.
                with gr.Accordion(labels["Advanced Options"], open=True):
                    with gr.Row():
                        latency = gr.Text(show_label=False)
                    with gr.Row():
                        seed = gr.Number(
                            label=labels["Seed"],
                            value=42,
                            precision=0,
                            minimum=0,
                            maximum=MAX_SEED,
                        )
                        randomize_seed = gr.Checkbox(
                            label=labels["Randomize Seed"],
                            value=True,
                        )
                    with gr.Row():
                        width = gr.Slider(
                            label=labels["Width"],
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=DEFAULT_WIDTH,
                        )
                        height = gr.Slider(
                            label=labels["Height"],
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=DEFAULT_HEIGHT,
                        )
                        num_inference_steps = gr.Slider(
                            label=labels["Inference Steps"],
                            minimum=1,
                            maximum=4,
                            step=1,
                            value=DEFAULT_INFERENCE_STEPS,
                        )
        with gr.Row():
            gr.Markdown(f"### 🌟 {labels['Inspiration Gallery']}")
        with gr.Row():
            # generate_example_image returns (image, seed, latency), so all
            # three output components are listed.
            gr.Examples(
                examples=examples,
                fn=generate_example_image,
                inputs=[prompt],
                outputs=[result, seed, latency],
                cache_examples=False,
            )

    # Event handling
    def validated_generate(*args):
        """Return the first result yielded by ``generate_image(*args)``.

        ``args`` follows the order of the ``inputs`` list of the listener
        below, so ``args[1]`` is the seed. On any exception the error is
        printed and ``(None, seed, error_text)`` is returned.
        """
        try:
            return next(generate_image(*args))
        except Exception as e:
            print(f"검증 생성 오류: {e}")
            return None, args[1], f"오류: {str(e)}"

    # Forward every control so the button honours the "Randomize Seed"
    # checkbox and the "Inference Steps" slider instead of silently using
    # the function defaults.
    enhanceBtn.click(
        fn=generate_image,
        inputs=[prompt, seed, width, height, randomize_seed, num_inference_steps],
        outputs=[result, seed, latency],
        show_progress="hidden",
        show_api=False,
        queue=False
    )
    # Regenerate whenever the prompt or a slider changes; only the latest
    # pending trigger is processed ("always_last").
    gr.on(
        triggers=[prompt.input, width.input, height.input, num_inference_steps.input],
        fn=validated_generate,
        inputs=[prompt, seed, width, height, randomize_seed, num_inference_steps],
        outputs=[result, seed, latency],
        show_progress="hidden",
        show_api=False,
        trigger_mode="always_last",
        queue=False
    )

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()