import gradio as gr
import torch
import jax
import numpy as np
from PIL import Image
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from super_resolve import process as thera_process  # assumes Thera's super_resolve module is on the path
# Configuration
DEVICE = "cpu"  # or "cuda" if available
JAX_DEVICE = jax.devices("cpu")[0]  # use CPU for JAX
# 1. Load the Thera models (EDSR/RDN)
# (implement following the original Thera code)
model_edsr, params_edsr = None, None  # load via pickle/HF Hub
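# A minimal loading sketch, commented out because the repo id, file name, and
# pickle keys below are illustrative assumptions, not Thera's official layout:
# import pickle
# from huggingface_hub import hf_hub_download
# ckpt_path = hf_hub_download(repo_id="prs-eth/thera-edsr-pro", filename="model.pkl")  # hypothetical names
# with open(ckpt_path, "rb") as f:
#     checkpoint = pickle.load(f)
# model_edsr, params_edsr = checkpoint["model"], checkpoint["params"]  # assumed keys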
# 2. Load SDXL Img2Img + LoRA
print("Loading SDXL Img2Img with LoRA...")
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32
).to(DEVICE)
pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
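# "BAS-RELIEF" is this LoRA's trigger token; full_pipeline() below prepends it to every prompt.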
# 3. Load the depth-estimation model
print("Loading DPT...")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(DEVICE)
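# Note: newer transformers releases name this class DPTImageProcessor;
# DPTFeatureExtractor still works as a deprecated alias.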
def enhance_depth_map(depth_arr):
    # Normalize depth to [0, 1] and convert to an 8-bit grayscale image
    depth_normalized = (depth_arr - depth_arr.min()) / (depth_arr.max() - depth_arr.min() + 1e-8)
    return Image.fromarray((depth_normalized * 255).astype(np.uint8))
def full_pipeline(image, prompt, scale_factor=2.0):
    # 1. Super-resolution with Thera
    image = image.convert("RGB")  # drop any alpha channel before processing
    source = np.array(image) / 255.0
    target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
    upscaled = thera_process(source, model_edsr, params_edsr, target_shape, do_ensemble=True)
    upscaled_pil = Image.fromarray((np.clip(upscaled, 0.0, 1.0) * 255).astype(np.uint8))  # clip guards against uint8 wrap-around
    # 2. Generate the bas-relief with SDXL Img2Img
    full_prompt = f"BAS-RELIEF {prompt}, intricate carving, marble relief"
    bas_relief = pipe(
        prompt=full_prompt,
        image=upscaled_pil,
        strength=0.7,  # denoising strength: how far the result may drift from the input
        num_inference_steps=25,
        guidance_scale=7.5
    ).images[0]
    # 3. Compute the depth map with DPT
    inputs = feature_extractor(images=bas_relief, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = depth_model(**inputs)
        depth = outputs.predicted_depth
    # PIL's .size is (width, height); interpolate expects (height, width)
    depth_map = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        size=bas_relief.size[::-1],
        mode="bicubic",
        align_corners=False
    ).squeeze().cpu().numpy()
    return upscaled_pil, bas_relief, enhance_depth_map(depth_map)
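# Quick smoke test without the UI (assumes the Thera models above were actually
# loaded; "input.jpg" is a placeholder path):
# up_img, relief_img, depth_img = full_pipeline(Image.open("input.jpg"), "ancient sculpture, marble", 2.0)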
# Gradio interface
with gr.Blocks(title="Super-Resolution + Bas-Relief") as app:
    gr.Markdown("## 📈 Super-Resolution + 🗿 Bas-Relief + 🗺️ Depth Map")
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Input Image")
            prompt = gr.Textbox("ancient sculpture, marble", label="Relief Description")
            scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
            btn = gr.Button("Process")
        with gr.Column():
            img_upscaled = gr.Image(label="Super-Resolved Image")
            img_basrelief = gr.Image(label="Sculptural Relief")
            img_depth = gr.Image(label="Depth Map")
    btn.click(
        full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[img_upscaled, img_basrelief, img_depth]
    )

if __name__ == "__main__":
    app.launch()