New logic

Files changed: app.py (+110 -59), requirements.txt (+4 -7)

app.py
CHANGED
@@ -1,87 +1,138 @@
 import gradio as gr
 import torch
 import numpy as np
-from PIL import Image, ImageEnhance
 from transformers import DPTFeatureExtractor, DPTForDepthEstimation
-from diffusers import StableDiffusionXLPipeline

-device = …
-torch_dtype = …

-print("Loading SDXL Base model...")
-pipe = StableDiffusionXLPipeline.from_pretrained(
-    "stabilityai/stable-diffusion-xl-base-1.0",
-    torch_dtype=torch_dtype
-).to(device)

-pipe.load_lora_weights(
-    "KappaNeuro/bas-relief",  # The HF repo with BAS-RELIEF.safetensors
-    weight_name="BAS-RELIEF.safetensors",
-    peft_backend="peft"  # This is crucial
-)

-# …
-…

-def enhance_depth_map(depth_arr: np.ndarray) -> Image.Image:
-    d_min, d_max = depth_arr.min(), depth_arr.max()
-    depth_stretched = (depth_arr - d_min) / (d_max - d_min + 1e-8)
-    depth_stretched = (depth_stretched * 255).astype(np.uint8)
-    depth_pil = Image.fromarray(depth_stretched)
-    …
-    enhancer = ImageEnhance.Sharpness(depth_pil)
-    depth_pil = enhancer.enhance(2.0)
-    return depth_pil

-def …
-    …
-    result = pipe(
-        prompt=full_prompt,
-        image=imagem,
-        num_inference_steps=15,  # reduce if too slow
-        guidance_scale=7.5,
-        height=512,  # reduce if you still get timeouts
-        width=512
-    )
-    image = result.images[0]

-    inputs = feature_extractor(image, return_tensors="pt").to(device)
     with torch.no_grad():
-        …
         predicted_depth = outputs.predicted_depth

     prediction = torch.nn.functional.interpolate(
         predicted_depth.unsqueeze(1),
         size=image.size[::-1],
         mode="bicubic",
-        align_corners=False
-    )

-    depth_map_pil = enhance_depth_map(prediction.cpu().numpy())
-    …

-# …
-…

 if __name__ == "__main__":
-    …
 import gradio as gr
 import torch
 import numpy as np
+from PIL import Image
+from peft import PeftModel
 from transformers import DPTFeatureExtractor, DPTForDepthEstimation
+from diffusers import StableDiffusionXLControlNetImg2ImgPipeline, ControlNetModel
+from torchvision import transforms

+# Initial setup
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

+# --- Model loading ---

+# 1. Thera: super-resolution
+def load_thera_model():
+    # Hypothetical model - adjust to match the actual Thera implementation
+    model = torch.hub.load('prs-eth/thera', 'thera', trust_repo=True)
+    return model.to(DEVICE)

+# 2. Depth map with PEFT
+def load_depth_model():
+    base_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
+    model = PeftModel.from_pretrained(base_model, "danube2024/dpt-peft-lora")
+    return model.to(DEVICE).eval()

+# 3. Bas-relief with ControlNet
+def load_controlnet():
+    controlnet = ControlNetModel.from_pretrained(
+        "danube2024/controlnet-bas-relief",
+        torch_dtype=TORCH_DTYPE
+    )
+    # Img2img variant: its __call__ accepts an init image, a control image,
+    # and a strength, which create_bas_relief below relies on
+    pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
+        "stabilityai/stable-diffusion-xl-base-1.0",
+        controlnet=controlnet,
+        torch_dtype=TORCH_DTYPE
+    )
+    pipe.load_lora_weights("danube2024/bas-relief-lora")
+    return pipe.to(DEVICE)

+# --- Processing ---

+def run_thera(image, model):
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize([0.5], [0.5])
+    ])
+    input_tensor = transform(image).unsqueeze(0).to(DEVICE)
     with torch.no_grad():
+        output = model(input_tensor)
+
+    output_img = transforms.ToPILImage()(output.squeeze().cpu().clamp(-1, 1) * 0.5 + 0.5)
+    return output_img
+
+
+def create_depth_map(image, model, feature_extractor):
+    inputs = feature_extractor(images=image, return_tensors="pt").to(DEVICE)
+    with torch.no_grad():
+        outputs = model(**inputs)
         predicted_depth = outputs.predicted_depth

     prediction = torch.nn.functional.interpolate(
         predicted_depth.unsqueeze(1),
         size=image.size[::-1],
         mode="bicubic",
+        align_corners=False,
+    )
+    return prediction.squeeze().cpu().numpy()
+
+
+def create_bas_relief(prompt, image, depth_map, pipe):
+    control_image = Image.fromarray((depth_map * 255).astype(np.uint8))
+
+    image = image.resize((1024, 1024))
+    control_image = control_image.resize((1024, 1024))
+
+    result = pipe(
+        prompt=prompt,
+        image=image,
+        control_image=control_image,
+        strength=0.8,
+        num_inference_steps=30
+    ).images[0]
+
+    return result
+
+
+# --- Gradio interface ---
+
+with gr.Blocks() as app:
+    gr.Markdown("# 🖼️ Super-Resolution + Depth Map + Bas-Relief")
+
+    with gr.Row():
+        with gr.Column():
+            input_image = gr.Image(type="pil", label="Input Image")
+            prompt = gr.Textbox("high quality bas-relief sculpture, intricate details")
+            submit_btn = gr.Button("Process")
+
+        with gr.Column():
+            upscaled_output = gr.Image(label="Upscaled Image")
+            depth_output = gr.Image(label="Depth Map")
+            basrelief_output = gr.Image(label="Bas-Relief Result")

+    def process(image, prompt):
+        # Load models
+        thera_model = load_thera_model()
+        depth_model = load_depth_model()
+        feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
+        basrelief_pipe = load_controlnet()

+        # 1. Super-resolution
+        upscaled = run_thera(image, thera_model)
+
+        # 2. Depth map
+        depth = create_depth_map(upscaled, depth_model, feature_extractor)
+        depth_normalized = (depth - depth.min()) / (depth.max() - depth.min())
+
+        # 3. Bas-relief
+        basrelief = create_bas_relief(prompt, upscaled, depth_normalized, basrelief_pipe)
+
+        return upscaled, depth_normalized, basrelief
+
+
+    submit_btn.click(
+        process,
+        inputs=[input_image, prompt],
+        outputs=[upscaled_output, depth_output, basrelief_output]
+    )

 if __name__ == "__main__":
+    app.launch()
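For reference, a minimal sketch of driving the new three-stage flow headlessly, without the Gradio UI. It assumes the functions above are importable from app.py, that the (explicitly hypothetical) Thera hub entry point and the danube2024 repositories resolve, and that input.jpg stands in for any local test image:

    from PIL import Image
    from transformers import DPTFeatureExtractor
    from app import (load_thera_model, load_depth_model, load_controlnet,
                     run_thera, create_depth_map, create_bas_relief)

    image = Image.open("input.jpg").convert("RGB")

    # Load each stage once, then chain them the same way process() does
    thera_model = load_thera_model()
    depth_model = load_depth_model()
    feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
    pipe = load_controlnet()

    upscaled = run_thera(image, thera_model)                          # 1. super-resolution
    depth = create_depth_map(upscaled, depth_model, feature_extractor)
    depth = (depth - depth.min()) / (depth.max() - depth.min())       # 2. normalized depth
    relief = create_bas_relief("high quality bas-relief sculpture, intricate details",
                               upscaled, depth, pipe)                 # 3. bas-relief
    relief.save("bas_relief.png")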
requirements.txt
CHANGED
@@ -1,8 +1,5 @@
-peft
-accelerate
-diffusers>=0.20.0
-transformers>=4.30.0
-torch
 gradio
-
-

 gradio
+torch
+peft
+transformers
+diffusers
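One note on the rewritten dependency list: it drops the old version pins and accelerate, and it does not list torchvision, which the new app.py imports and which torch itself does not pull in. A minimal, hypothetical startup check that mirrors app.py's imports:

    import importlib

    # Every top-level module app.py imports; PIL ships as the pillow package,
    # and numpy/PIL normally arrive as dependencies of gradio. torchvision is
    # not in requirements.txt, so this catches it before the UI starts.
    for mod in ("gradio", "torch", "torchvision", "numpy",
                "PIL", "peft", "transformers", "diffusers"):
        try:
            importlib.import_module(mod)
        except ImportError as err:
            raise SystemExit(f"missing dependency: {mod} ({err})")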