File size: 3,342 Bytes
98889c8
 
1eb87a5
98889c8
1665fe1
1eb87a5
98889c8
1eb87a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1665fe1
1eb87a5
 
1665fe1
1eb87a5
 
1665fe1
1eb87a5
 
 
 
 
1665fe1
1eb87a5
1665fe1
 
1eb87a5
 
 
1665fe1
 
 
1eb87a5
 
 
 
1665fe1
 
1eb87a5
 
 
 
 
 
 
 
1665fe1
98889c8
 
1665fe1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gradio as gr
import torch
import jax
import numpy as np
from PIL import Image
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from super_resolve import process as thera_process  # Assume imports do Thera

# Configuration
# Use the GPU when PyTorch can see one (the original comment asked for "cuda"
# if available); otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# JAX device pinned to the CPU for Thera.
# NOTE(review): not referenced elsewhere in this file — confirm whether
# thera_process relies on it.
JAX_DEVICE = jax.devices("cpu")[0]

# 1. Load the Thera super-resolution model (EDSR/RDN backbone).
# TODO: implement following the original Thera code (e.g. load via pickle or
# the HF Hub). These placeholders are handed to thera_process as-is, so
# super-resolution cannot work while they are None.
model_edsr, params_edsr = None, None

# 2. Load SDXL img2img and apply the bas-relief LoRA on top of the base weights.
print("Carregando SDXL Img2Img com LoRA...")
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,
).to(DEVICE)
pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")

# 3. Load the DPT depth-estimation model and its preprocessor.
print("Carregando DPT...")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(DEVICE)


def enhance_depth_map(depth_arr):
    """Min-max normalise a depth array into an 8-bit grayscale PIL image.

    Args:
        depth_arr: Numeric array of depth values (full_pipeline passes a 2-D map).

    Returns:
        A PIL image in which the array minimum maps to 0 and the maximum to 255.
        A constant (flat) array yields an all-zero image.
    """
    depth = np.asarray(depth_arr, dtype=np.float64)
    low = depth.min()
    span = depth.max() - low
    if span > 0:
        normalized = (depth - low) / span
    else:
        # Flat map: nothing to stretch, emit black (same result as before).
        normalized = np.zeros_like(depth)
    # Round instead of truncating: astype(uint8) truncates toward zero, so the
    # old epsilon-padded division mapped the maximum to 254 rather than 255.
    return Image.fromarray(np.round(normalized * 255).astype(np.uint8))


def _super_resolve(image, scale_factor):
    """Upscale a PIL image by ``scale_factor`` with Thera and return an RGB PIL image."""
    if model_edsr is None or params_edsr is None:
        raise gr.Error("Modelo Thera (EDSR) não carregado.")
    # Force 3-channel RGB: uploads may be RGBA (PNG with alpha) or grayscale,
    # which would give the array an unexpected channel count.
    source = np.asarray(image.convert("RGB")) / 255.0
    target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
    upscaled = thera_process(source, model_edsr, params_edsr, target_shape, do_ensemble=True)
    # The network output is not guaranteed to stay inside [0, 1]; clip before the
    # uint8 cast so over/undershoots saturate instead of wrapping around.
    upscaled = np.clip(np.asarray(upscaled), 0.0, 1.0)
    return Image.fromarray(np.round(upscaled * 255).astype(np.uint8))


def _generate_bas_relief(image, prompt):
    """Restyle ``image`` as a bas-relief with the SDXL img2img + LoRA pipeline."""
    # "BAS-RELIEF" leads the prompt, presumably as the LoRA trigger word.
    full_prompt = f"BAS-RELIEF {prompt}, intricate carving, marble relief"
    result = pipe(
        prompt=full_prompt,
        image=image,
        strength=0.7,
        num_inference_steps=25,
        guidance_scale=7.5,
    )
    return result.images[0]


def _estimate_depth(image):
    """Run DPT on ``image`` and return a float numpy depth map resized to its (H, W)."""
    inputs = feature_extractor(image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        predicted = depth_model(**inputs).predicted_depth
        # predicted_depth is at the model's working resolution; assumes shape
        # (batch, H', W') — unsqueeze adds the channel dim interpolate expects.
        # PIL .size is (W, H); interpolate wants (H, W), hence the reversal.
        resized = torch.nn.functional.interpolate(
            predicted.unsqueeze(1),
            size=image.size[::-1],
            mode="bicubic",
            align_corners=False,
        )
    return resized.squeeze().cpu().numpy()


def full_pipeline(image, prompt, scale_factor=2.0):
    """Super-resolve ``image``, turn it into a bas-relief and estimate its depth.

    Args:
        image: Input PIL image from the Gradio upload (``None`` if nothing was uploaded).
        prompt: Free-text description appended to the bas-relief prompt.
        scale_factor: Upscaling factor applied to both image dimensions.

    Returns:
        Tuple ``(upscaled, bas_relief, depth_map)`` of PIL images.

    Raises:
        gr.Error: if no image was provided or the Thera model is not loaded.
    """
    if image is None:
        raise gr.Error("Envie uma imagem de entrada.")
    upscaled_pil = _super_resolve(image, scale_factor)
    bas_relief = _generate_bas_relief(upscaled_pil, prompt)
    depth_map = _estimate_depth(bas_relief)
    return upscaled_pil, bas_relief, enhance_depth_map(depth_map)


# Gradio interface: inputs in the left column, the three results on the right.
with gr.Blocks(title="Super Resolução + Bas-Relief") as app:
    gr.Markdown(value="## 📈 Super Resolução + 🗿 Bas-Relief + 🗺️ Mapa de Profundidade")

    with gr.Row():
        # Left column: user-controlled inputs, in full_pipeline's argument order.
        with gr.Column():
            img_input = gr.Image(label="Imagem de Entrada", type="pil")
            prompt = gr.Textbox(value="ancient sculpture, marble", label="Descrição do Relevo")
            scale = gr.Slider(minimum=1.0, maximum=4.0, value=2.0, label="Fator de Escala")
            btn = gr.Button(value="Processar")

        # Right column: outputs, in the order full_pipeline returns them.
        with gr.Column():
            img_upscaled = gr.Image(label="Imagem Super Resolvida")
            img_basrelief = gr.Image(label="Relevo Escultural")
            img_depth = gr.Image(label="Mapa de Profundidade")

    # Clicking the button runs the whole pipeline once.
    btn.click(
        fn=full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[img_upscaled, img_basrelief, img_depth],
    )

if __name__ == "__main__":
    app.launch()