import gradio as gr
import torch
import jax
import numpy as np
from PIL import Image
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation

from super_resolve import process as thera_process  # assumes the Thera repo's super_resolve module is on the path

# Configuration
DEVICE = "cpu"  # or "cuda" if available
JAX_DEVICE = jax.devices("cpu")[0]  # use the CPU for JAX

# 1. Load the Thera models (EDSR/RDN)
# (Implement as in the original Thera code)
model_edsr, params_edsr = None, None  # load via pickle / HF Hub; see the sketch at the end of this file

# 2. Load SDXL Img2Img + LoRA
print("Loading SDXL Img2Img with LoRA...")
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,
).to(DEVICE)
pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")

# 3. Load the depth-estimation model
print("Loading DPT...")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(DEVICE)


def enhance_depth_map(depth_arr):
    """Normalize a raw depth array to [0, 255] and return it as a grayscale PIL image."""
    depth_normalized = (depth_arr - depth_arr.min()) / (depth_arr.max() - depth_arr.min() + 1e-8)
    return Image.fromarray((depth_normalized * 255).astype(np.uint8))


def full_pipeline(image, prompt, scale_factor=2.0):
    # 1. Super-resolution with Thera
    source = np.array(image.convert("RGB")) / 255.0  # force RGB in case the upload has an alpha channel
    target_shape = (
        int(image.height * scale_factor),
        int(image.width * scale_factor),
    )
    upscaled = thera_process(source, model_edsr, params_edsr, target_shape, do_ensemble=True)
    upscaled_pil = Image.fromarray((upscaled * 255).astype(np.uint8))

    # 2. Generate the bas-relief with SDXL Img2Img
    full_prompt = f"BAS-RELIEF {prompt}, intricate carving, marble relief"
    bas_relief = pipe(
        prompt=full_prompt,
        image=upscaled_pil,
        strength=0.7,
        num_inference_steps=25,
        guidance_scale=7.5,
    ).images[0]

    # 3. Compute the depth map
    inputs = feature_extractor(images=bas_relief, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = depth_model(**inputs)
        depth = outputs.predicted_depth

    depth_map = torch.nn.functional.interpolate(
        depth.unsqueeze(1),                # (B, H, W) -> (B, 1, H, W) for interpolate
        size=bas_relief.size[::-1],        # PIL size is (W, H); interpolate expects (H, W)
        mode="bicubic",
    ).squeeze().cpu().numpy()

    return upscaled_pil, bas_relief, enhance_depth_map(depth_map)


# Gradio interface
with gr.Blocks(title="Super-Resolution + Bas-Relief") as app:
    gr.Markdown("## 📈 Super-Resolution + 🗿 Bas-Relief + 🗺️ Depth Map")

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Input Image")
            prompt = gr.Textbox("ancient sculpture, marble", label="Relief Description")
            scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
            btn = gr.Button("Process")
        with gr.Column():
            img_upscaled = gr.Image(label="Upscaled Image")
            img_basrelief = gr.Image(label="Sculptural Relief")
            img_depth = gr.Image(label="Depth Map")

    btn.click(
        full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[img_upscaled, img_basrelief, img_depth],
    )

if __name__ == "__main__":
    app.launch()
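
# --- Sketch: populating model_edsr / params_edsr ---
# The script above leaves the Thera model as None. Below is a minimal,
# hedged sketch of one way to fetch a pickled checkpoint from the HF Hub
# and unpack it. The repo id, filename, and checkpoint layout
# ({"model": ..., "params": ...}) are ASSUMPTIONS for illustration,
# not confirmed by the original Thera code; adapt them to the real
# checkpoint format before use.
#
# import pickle
# from huggingface_hub import hf_hub_download
#
# def load_thera_checkpoint(repo_id, filename):
#     """Download a pickled Thera checkpoint and return (model, params)."""
#     path = hf_hub_download(repo_id=repo_id, filename=filename)
#     with open(path, "rb") as f:
#         checkpoint = pickle.load(f)
#     # Assumed layout: a dict holding the model object and its JAX params.
#     return checkpoint["model"], checkpoint["params"]
#
# model_edsr, params_edsr = load_thera_checkpoint(
#     "prs-eth/thera-edsr-pro",  # hypothetical repo id
#     "model.pkl",               # hypothetical filename
# )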