# sculpt/app.py
import gradio as gr
import torch
import jax
import numpy as np
from PIL import Image
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from super_resolve import process as thera_process  # assumes the Thera imports are available
# Configuration
DEVICE = "cpu"  # or "cuda" if available
JAX_DEVICE = jax.devices("cpu")[0]  # use CPU for JAX
# 1. Load the Thera models (EDSR/RDN)
# (Implement as in the original Thera code)
model_edsr, params_edsr = None, None  # load via pickle/HF Hub
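# A minimal sketch of the loading step stubbed out above, assuming the Thera
# EDSR checkpoint is published as a pickle on the HF Hub. The repo id, filename,
# and checkpoint layout below are hypothetical placeholders, not confirmed here:
#
#   import pickle
#   from huggingface_hub import hf_hub_download
#
#   ckpt_path = hf_hub_download("prs-eth/thera-edsr", "model.pkl")  # hypothetical repo/file
#   with open(ckpt_path, "rb") as f:
#       ckpt = pickle.load(f)
#   model_edsr, params_edsr = ckpt["model"], ckpt["params"]  # assumed checkpoint layout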
# 2. Load SDXL Img2Img + LoRA
print("Loading SDXL Img2Img with LoRA...")
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32
).to(DEVICE)
pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
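# Optional: attention slicing trades speed for a lower peak-memory footprint,
# which can help when running this SDXL pipeline on CPU (uncomment to enable):
# pipe.enable_attention_slicing()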
# 3. Load the depth-estimation model
print("Loading DPT...")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(DEVICE)
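# Note: recent transformers releases expose this preprocessing as DPTImageProcessor;
# DPTFeatureExtractor still works and is kept here as in the original code.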
def enhance_depth_map(depth_arr):
    # Normalize depth to [0, 1] (the epsilon guards against division by zero),
    # then convert to an 8-bit grayscale PIL image.
    depth_normalized = (depth_arr - depth_arr.min()) / (depth_arr.max() - depth_arr.min() + 1e-8)
    return Image.fromarray((depth_normalized * 255).astype(np.uint8))
def full_pipeline(image, prompt, scale_factor=2.0):
    # 1. Super-resolution with Thera
    source = np.array(image.convert("RGB")) / 255.0  # force 3 channels, scale to [0, 1]
    target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
    upscaled = thera_process(source, model_edsr, params_edsr, target_shape, do_ensemble=True)
    upscaled_pil = Image.fromarray((np.clip(upscaled, 0, 1) * 255).astype(np.uint8))
    # 2. Generate the bas-relief with SDXL Img2Img
    full_prompt = f"BAS-RELIEF {prompt}, intricate carving, marble relief"
    bas_relief = pipe(
        prompt=full_prompt,
        image=upscaled_pil,
        strength=0.7,
        num_inference_steps=25,
        guidance_scale=7.5
    ).images[0]
    # 3. Compute the depth map
    inputs = feature_extractor(images=bas_relief, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = depth_model(**inputs)
        depth = outputs.predicted_depth

    # Resize the predicted depth back to the bas-relief resolution
    # (PIL's size is (width, height); interpolate expects (height, width)).
    depth_map = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        size=bas_relief.size[::-1],
        mode="bicubic"
    ).squeeze().cpu().numpy()

    return upscaled_pil, bas_relief, enhance_depth_map(depth_map)
# Gradio interface
with gr.Blocks(title="Super-Resolution + Bas-Relief") as app:
    gr.Markdown("## 📈 Super-Resolution + 🗿 Bas-Relief + 🗺️ Depth Map")

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Input Image")
            prompt = gr.Textbox("ancient sculpture, marble", label="Relief Description")
            scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
            btn = gr.Button("Process")
        with gr.Column():
            img_upscaled = gr.Image(label="Super-Resolved Image")
            img_basrelief = gr.Image(label="Sculptural Relief")
            img_depth = gr.Image(label="Depth Map")
    btn.click(
        full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[img_upscaled, img_basrelief, img_depth]
    )
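# Quick local smoke test, bypassing the UI (the image path below is a
# hypothetical placeholder, not a file shipped with this app):
#   from PIL import Image
#   up, relief, depth = full_pipeline(Image.open("input.jpg"), "ancient sculpture, marble", 2.0)
#   up.save("upscaled.png"); relief.save("bas_relief.png"); depth.save("depth.png")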
if __name__ == "__main__":
    app.launch()