File size: 8,003 Bytes
357df1b 65579be 98889c8 19a6d73 98889c8 b82dc7d 19a6d73 b82dc7d 357df1b a7111d1 b82dc7d f41a4a7 b82dc7d 65579be 357df1b f41a4a7 65579be 357df1b 65579be 357df1b 65579be 357df1b 65579be 357df1b 65579be 888435a 357df1b 888435a 65579be 357df1b 65579be 357df1b 65579be 357df1b 65579be 357df1b 65579be 357df1b 65579be 357df1b 65579be 357df1b 65579be 357df1b 65579be 357df1b 65579be 357df1b 65579be 357df1b 65579be 357df1b 65579be 19a6d73 98889c8 65579be 357df1b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 |
# app.py
import logging
import gradio as gr
import torch
import numpy as np
import jax
import pickle
from PIL import Image
from huggingface_hub import hf_hub_download
from model import build_thera
from super_resolve import process
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
# ================== CONFIGURAÇÃO DE LOGGING ==================
class CustomLogger:
def __init__(self, name):
self.logger = logging.getLogger(name)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
self.logger.addHandler(handler)
self.logger.setLevel(logging.INFO)
def divider(self, text=None, length=60):
if text:
# Cálculo seguro do número de '='
available_space = length - len(text) - 12 # 10 '=' + 2 espaços
if available_space < 1:
available_space = 1 # Garante pelo menos 1 '='
msg = f"\n{'=' * 10} {text.upper()} {'=' * available_space}"
else:
msg = "\n" + "=" * length
self.logger.info(msg)
def etapa(self, text):
self.logger.info(f"▶ {text}")
def success(self, text):
self.logger.info(f"✓ {text}")
def error(self, text):
self.logger.error(f"✗ {text}")
def warning(self, text):
self.logger.warning(f"⚠ {text}")
logger = CustomLogger(__name__)
# ================== CONFIGURAÇÃO DE HARDWARE ==================
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
logger.divider("Configuração Inicial")
logger.success(f"Dispositivo detectado: {device.upper()}")
logger.success(f"Precisão numérica: {str(torch_dtype).replace('torch.', '')}")
# ================== CARREGAMENTO DE MODELOS ==================
def carregar_modelo_thera(repo_id):
try:
logger.divider(f"Carregando Modelo: {repo_id}")
model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
with open(model_path, 'rb') as f:
check = pickle.load(f)
model = build_thera(3, check['backbone'], check['size'])
params = check['model']
logger.success(f"Modelo {repo_id} carregado")
return model, params
except Exception as e:
logger.error(f"Falha ao carregar {repo_id}: {str(e)}")
return None, None
# Carregar modelos Thera
try:
modelo_edsr, params_edsr = carregar_modelo_thera("prs-eth/thera-edsr-pro")
modelo_rdn, params_rdn = carregar_modelo_thera("prs-eth/thera-rdn-pro")
except Exception as e:
logger.error("Falha crítica no carregamento dos modelos Thera")
raise
# ================== PIPELINE DE ARTE ==================
pipe = None
modelo_profundidade = None
try:
logger.divider("Configurando Pipeline de Arte")
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch_dtype,
variant="fp16",
use_safetensors=True
).to(device)
pipe.load_lora_weights(
"KappaNeuro/bas-relief",
weight_name="BAS-RELIEF.safetensors"
)
logger.success("Pipeline SDXL configurado")
logger.etapa("Configurando Modelo de Profundidade")
processador_profundidade = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
modelo_profundidade = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
logger.success("Modelo de profundidade pronto")
except Exception as e:
logger.error(f"Erro na configuração da GPU: {str(e)}")
pipe = None
modelo_profundidade = None
# ================== FLUXO DE PROCESSAMENTO ==================
def pipeline_completo(imagem, fator_escala, modelo_escolhido, prompt_estilo):
try:
logger.divider("Novo Processamento")
# Converter entrada
if not isinstance(imagem, Image.Image):
imagem = Image.fromarray(imagem)
# ========= SUPER-RESOLUÇÃO =========
logger.etapa("Processando Super-Resolução")
modelo = modelo_edsr if modelo_escolhido == "EDSR" else modelo_rdn
params = params_edsr if modelo_escolhido == "EDSR" else params_rdn
sr_array = process(
np.array(imagem) / 255.,
modelo,
params,
(round(imagem.size[1] * fator_escala),
round(imagem.size[0] * fator_escala)),
True
)
sr_pil = Image.fromarray(np.array(sr_array)).convert("RGB")
logger.success(f"Super-Resolução: {sr_pil.size[0]}x{sr_pil.size[1]}")
# ========= ESTILO BAIXO-RELEVO =========
arte_pil = None
if pipe and modelo_profundidade:
try:
logger.etapa("Aplicando Estilo Artístico")
resultado = pipe(
prompt=f"BAS-RELIEF {prompt_estilo}, intricate marble carving, 8k ultra HD",
image=sr_pil,
strength=0.65,
num_inference_steps=30,
guidance_scale=7.5
)
arte_pil = resultado.images[0]
logger.success(f"Arte gerada: {arte_pil.size[0]}x{arte_pil.size[1]}")
except Exception as e:
logger.error(f"Falha no estilo: {str(e)}")
# ========= MAPA DE PROFUNDIDADE =========
mapa_pil = None
if arte_pil and modelo_profundidade:
try:
logger.etapa("Calculando Profundidade")
inputs = processador_profundidade(images=arte_pil, return_tensors="pt").to(device)
with torch.no_grad():
outputs = modelo_profundidade(**inputs)
depth = outputs.predicted_depth
depth = torch.nn.functional.interpolate(
depth.unsqueeze(1).float(),
size=arte_pil.size[::-1],
mode="bicubic"
).squeeze().cpu().numpy()
depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
mapa_pil = Image.fromarray((depth * 255).astype(np.uint8))
logger.success("Mapa de profundidade calculado")
except Exception as e:
logger.error(f"Falha na profundidade: {str(e)}")
return sr_pil, arte_pil or sr_pil, mapa_pil or sr_pil
except Exception as e:
logger.error(f"Erro no pipeline: {str(e)}")
return None, None, None
# ================== INTERFACE GRADIO ==================
with gr.Blocks(title="TheraSR Art Suite", theme=gr.themes.Soft()) as app:
gr.Markdown("# 🎨 TheraSR - Super Resolução & Arte Generativa")
with gr.Row():
with gr.Column():
entrada_imagem = gr.Image(label="Imagem de Entrada", type="pil")
seletor_modelo = gr.Radio(
["EDSR", "RDN"],
value="EDSR",
label="Modelo de Super-Resolução"
)
controle_escala = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
entrada_prompt = gr.Textbox(
label="Prompt de Estilo",
value="insanely detailed ancient greek marble浮雕, 8k cinematic lighting"
)
botao_processar = gr.Button("Gerar", variant="primary")
with gr.Column():
saida_sr = gr.Image(label="Super-Resolução", show_label=True)
saida_arte = gr.Image(label="Arte em Relevo", show_label=True)
saida_profundidade = gr.Image(label="Mapa de Profundidade", show_label=True)
botao_processar.click(
pipeline_completo,
inputs=[entrada_imagem, controle_escala, seletor_modelo, entrada_prompt],
outputs=[saida_sr, saida_arte, saida_profundidade]
)
if __name__ == "__main__":
app.launch(server_name="0.0.0.0", server_port=7860) |