sculpt / app.py
ds1david's picture
fixing bugs
888435a
raw
history blame
8 kB
# app.py
import logging
import gradio as gr
import torch
import numpy as np
import jax
import pickle
from PIL import Image
from huggingface_hub import hf_hub_download
from model import build_thera
from super_resolve import process
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
# ================== CONFIGURAÇÃO DE LOGGING ==================
class CustomLogger:
def __init__(self, name):
self.logger = logging.getLogger(name)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
self.logger.addHandler(handler)
self.logger.setLevel(logging.INFO)
def divider(self, text=None, length=60):
if text:
# Cálculo seguro do número de '='
available_space = length - len(text) - 12 # 10 '=' + 2 espaços
if available_space < 1:
available_space = 1 # Garante pelo menos 1 '='
msg = f"\n{'=' * 10} {text.upper()} {'=' * available_space}"
else:
msg = "\n" + "=" * length
self.logger.info(msg)
def etapa(self, text):
self.logger.info(f"▶ {text}")
def success(self, text):
self.logger.info(f"✓ {text}")
def error(self, text):
self.logger.error(f"✗ {text}")
def warning(self, text):
self.logger.warning(f"⚠ {text}")
logger = CustomLogger(__name__)
# ================== CONFIGURAÇÃO DE HARDWARE ==================
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
logger.divider("Configuração Inicial")
logger.success(f"Dispositivo detectado: {device.upper()}")
logger.success(f"Precisão numérica: {str(torch_dtype).replace('torch.', '')}")
# ================== CARREGAMENTO DE MODELOS ==================
def carregar_modelo_thera(repo_id):
try:
logger.divider(f"Carregando Modelo: {repo_id}")
model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
with open(model_path, 'rb') as f:
check = pickle.load(f)
model = build_thera(3, check['backbone'], check['size'])
params = check['model']
logger.success(f"Modelo {repo_id} carregado")
return model, params
except Exception as e:
logger.error(f"Falha ao carregar {repo_id}: {str(e)}")
return None, None
# Carregar modelos Thera
try:
modelo_edsr, params_edsr = carregar_modelo_thera("prs-eth/thera-edsr-pro")
modelo_rdn, params_rdn = carregar_modelo_thera("prs-eth/thera-rdn-pro")
except Exception as e:
logger.error("Falha crítica no carregamento dos modelos Thera")
raise
# ================== PIPELINE DE ARTE ==================
pipe = None
modelo_profundidade = None
try:
logger.divider("Configurando Pipeline de Arte")
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch_dtype,
variant="fp16",
use_safetensors=True
).to(device)
pipe.load_lora_weights(
"KappaNeuro/bas-relief",
weight_name="BAS-RELIEF.safetensors"
)
logger.success("Pipeline SDXL configurado")
logger.etapa("Configurando Modelo de Profundidade")
processador_profundidade = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
modelo_profundidade = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
logger.success("Modelo de profundidade pronto")
except Exception as e:
logger.error(f"Erro na configuração da GPU: {str(e)}")
pipe = None
modelo_profundidade = None
# ================== FLUXO DE PROCESSAMENTO ==================
def pipeline_completo(imagem, fator_escala, modelo_escolhido, prompt_estilo):
try:
logger.divider("Novo Processamento")
# Converter entrada
if not isinstance(imagem, Image.Image):
imagem = Image.fromarray(imagem)
# ========= SUPER-RESOLUÇÃO =========
logger.etapa("Processando Super-Resolução")
modelo = modelo_edsr if modelo_escolhido == "EDSR" else modelo_rdn
params = params_edsr if modelo_escolhido == "EDSR" else params_rdn
sr_array = process(
np.array(imagem) / 255.,
modelo,
params,
(round(imagem.size[1] * fator_escala),
round(imagem.size[0] * fator_escala)),
True
)
sr_pil = Image.fromarray(np.array(sr_array)).convert("RGB")
logger.success(f"Super-Resolução: {sr_pil.size[0]}x{sr_pil.size[1]}")
# ========= ESTILO BAIXO-RELEVO =========
arte_pil = None
if pipe and modelo_profundidade:
try:
logger.etapa("Aplicando Estilo Artístico")
resultado = pipe(
prompt=f"BAS-RELIEF {prompt_estilo}, intricate marble carving, 8k ultra HD",
image=sr_pil,
strength=0.65,
num_inference_steps=30,
guidance_scale=7.5
)
arte_pil = resultado.images[0]
logger.success(f"Arte gerada: {arte_pil.size[0]}x{arte_pil.size[1]}")
except Exception as e:
logger.error(f"Falha no estilo: {str(e)}")
# ========= MAPA DE PROFUNDIDADE =========
mapa_pil = None
if arte_pil and modelo_profundidade:
try:
logger.etapa("Calculando Profundidade")
inputs = processador_profundidade(images=arte_pil, return_tensors="pt").to(device)
with torch.no_grad():
outputs = modelo_profundidade(**inputs)
depth = outputs.predicted_depth
depth = torch.nn.functional.interpolate(
depth.unsqueeze(1).float(),
size=arte_pil.size[::-1],
mode="bicubic"
).squeeze().cpu().numpy()
depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
mapa_pil = Image.fromarray((depth * 255).astype(np.uint8))
logger.success("Mapa de profundidade calculado")
except Exception as e:
logger.error(f"Falha na profundidade: {str(e)}")
return sr_pil, arte_pil or sr_pil, mapa_pil or sr_pil
except Exception as e:
logger.error(f"Erro no pipeline: {str(e)}")
return None, None, None
# ================== INTERFACE GRADIO ==================
with gr.Blocks(title="TheraSR Art Suite", theme=gr.themes.Soft()) as app:
gr.Markdown("# 🎨 TheraSR - Super Resolução & Arte Generativa")
with gr.Row():
with gr.Column():
entrada_imagem = gr.Image(label="Imagem de Entrada", type="pil")
seletor_modelo = gr.Radio(
["EDSR", "RDN"],
value="EDSR",
label="Modelo de Super-Resolução"
)
controle_escala = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
entrada_prompt = gr.Textbox(
label="Prompt de Estilo",
value="insanely detailed ancient greek marble浮雕, 8k cinematic lighting"
)
botao_processar = gr.Button("Gerar", variant="primary")
with gr.Column():
saida_sr = gr.Image(label="Super-Resolução", show_label=True)
saida_arte = gr.Image(label="Arte em Relevo", show_label=True)
saida_profundidade = gr.Image(label="Mapa de Profundidade", show_label=True)
botao_processar.click(
pipeline_completo,
inputs=[entrada_imagem, controle_escala, seletor_modelo, entrada_prompt],
outputs=[saida_sr, saida_arte, saida_profundidade]
)
if __name__ == "__main__":
app.launch(server_name="0.0.0.0", server_port=7860)