|
|
|
import logging |
|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
import jax |
|
import pickle |
|
from PIL import Image |
|
from huggingface_hub import hf_hub_download |
|
from model import build_thera |
|
from super_resolve import process |
|
from diffusers import StableDiffusionXLImg2ImgPipeline |
|
from transformers import DPTFeatureExtractor, DPTForDepthEstimation |
|
|
|
|
|
|
|
class CustomLogger: |
|
def __init__(self, name): |
|
self.logger = logging.getLogger(name) |
|
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') |
|
handler = logging.StreamHandler() |
|
handler.setFormatter(formatter) |
|
self.logger.addHandler(handler) |
|
self.logger.setLevel(logging.INFO) |
|
|
|
|
|
def divider(self, text=None, length=60): |
|
if text: |
|
|
|
available_space = length - len(text) - 12 |
|
if available_space < 1: |
|
available_space = 1 |
|
msg = f"\n{'=' * 10} {text.upper()} {'=' * available_space}" |
|
else: |
|
msg = "\n" + "=" * length |
|
self.logger.info(msg) |
|
|
|
|
|
def etapa(self, text): |
|
self.logger.info(f"▶ {text}") |
|
|
|
def success(self, text): |
|
self.logger.info(f"✓ {text}") |
|
|
|
def error(self, text): |
|
self.logger.error(f"✗ {text}") |
|
|
|
def warning(self, text): |
|
self.logger.warning(f"⚠ {text}") |
|
|
|
|
|
logger = CustomLogger(__name__) |
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
torch_dtype = torch.float16 if device == "cuda" else torch.float32 |
|
logger.divider("Configuração Inicial") |
|
logger.success(f"Dispositivo detectado: {device.upper()}") |
|
logger.success(f"Precisão numérica: {str(torch_dtype).replace('torch.', '')}") |
|
|
|
|
|
|
|
def carregar_modelo_thera(repo_id): |
|
try: |
|
logger.divider(f"Carregando Modelo: {repo_id}") |
|
model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl") |
|
with open(model_path, 'rb') as f: |
|
check = pickle.load(f) |
|
model = build_thera(3, check['backbone'], check['size']) |
|
params = check['model'] |
|
logger.success(f"Modelo {repo_id} carregado") |
|
return model, params |
|
except Exception as e: |
|
logger.error(f"Falha ao carregar {repo_id}: {str(e)}") |
|
return None, None |
|
|
|
|
|
|
|
try: |
|
modelo_edsr, params_edsr = carregar_modelo_thera("prs-eth/thera-edsr-pro") |
|
modelo_rdn, params_rdn = carregar_modelo_thera("prs-eth/thera-rdn-pro") |
|
except Exception as e: |
|
logger.error("Falha crítica no carregamento dos modelos Thera") |
|
raise |
|
|
|
|
|
pipe = None |
|
modelo_profundidade = None |
|
|
|
try: |
|
logger.divider("Configurando Pipeline de Arte") |
|
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( |
|
"stabilityai/stable-diffusion-xl-base-1.0", |
|
torch_dtype=torch_dtype, |
|
variant="fp16", |
|
use_safetensors=True |
|
).to(device) |
|
|
|
pipe.load_lora_weights( |
|
"KappaNeuro/bas-relief", |
|
weight_name="BAS-RELIEF.safetensors" |
|
) |
|
logger.success("Pipeline SDXL configurado") |
|
|
|
logger.etapa("Configurando Modelo de Profundidade") |
|
processador_profundidade = DPTFeatureExtractor.from_pretrained("Intel/dpt-large") |
|
modelo_profundidade = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device) |
|
logger.success("Modelo de profundidade pronto") |
|
|
|
except Exception as e: |
|
logger.error(f"Erro na configuração da GPU: {str(e)}") |
|
pipe = None |
|
modelo_profundidade = None |
|
|
|
|
|
|
|
def pipeline_completo(imagem, fator_escala, modelo_escolhido, prompt_estilo): |
|
try: |
|
logger.divider("Novo Processamento") |
|
|
|
|
|
if not isinstance(imagem, Image.Image): |
|
imagem = Image.fromarray(imagem) |
|
|
|
|
|
logger.etapa("Processando Super-Resolução") |
|
modelo = modelo_edsr if modelo_escolhido == "EDSR" else modelo_rdn |
|
params = params_edsr if modelo_escolhido == "EDSR" else params_rdn |
|
|
|
sr_array = process( |
|
np.array(imagem) / 255., |
|
modelo, |
|
params, |
|
(round(imagem.size[1] * fator_escala), |
|
round(imagem.size[0] * fator_escala)), |
|
True |
|
) |
|
|
|
sr_pil = Image.fromarray(np.array(sr_array)).convert("RGB") |
|
logger.success(f"Super-Resolução: {sr_pil.size[0]}x{sr_pil.size[1]}") |
|
|
|
|
|
arte_pil = None |
|
if pipe and modelo_profundidade: |
|
try: |
|
logger.etapa("Aplicando Estilo Artístico") |
|
resultado = pipe( |
|
prompt=f"BAS-RELIEF {prompt_estilo}, intricate marble carving, 8k ultra HD", |
|
image=sr_pil, |
|
strength=0.65, |
|
num_inference_steps=30, |
|
guidance_scale=7.5 |
|
) |
|
arte_pil = resultado.images[0] |
|
logger.success(f"Arte gerada: {arte_pil.size[0]}x{arte_pil.size[1]}") |
|
except Exception as e: |
|
logger.error(f"Falha no estilo: {str(e)}") |
|
|
|
|
|
mapa_pil = None |
|
if arte_pil and modelo_profundidade: |
|
try: |
|
logger.etapa("Calculando Profundidade") |
|
inputs = processador_profundidade(images=arte_pil, return_tensors="pt").to(device) |
|
with torch.no_grad(): |
|
outputs = modelo_profundidade(**inputs) |
|
depth = outputs.predicted_depth |
|
|
|
depth = torch.nn.functional.interpolate( |
|
depth.unsqueeze(1).float(), |
|
size=arte_pil.size[::-1], |
|
mode="bicubic" |
|
).squeeze().cpu().numpy() |
|
|
|
depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8) |
|
mapa_pil = Image.fromarray((depth * 255).astype(np.uint8)) |
|
logger.success("Mapa de profundidade calculado") |
|
except Exception as e: |
|
logger.error(f"Falha na profundidade: {str(e)}") |
|
|
|
return sr_pil, arte_pil or sr_pil, mapa_pil or sr_pil |
|
|
|
except Exception as e: |
|
logger.error(f"Erro no pipeline: {str(e)}") |
|
return None, None, None |
|
|
|
|
|
|
|
with gr.Blocks(title="TheraSR Art Suite", theme=gr.themes.Soft()) as app: |
|
gr.Markdown("# 🎨 TheraSR - Super Resolução & Arte Generativa") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
entrada_imagem = gr.Image(label="Imagem de Entrada", type="pil") |
|
seletor_modelo = gr.Radio( |
|
["EDSR", "RDN"], |
|
value="EDSR", |
|
label="Modelo de Super-Resolução" |
|
) |
|
controle_escala = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala") |
|
entrada_prompt = gr.Textbox( |
|
label="Prompt de Estilo", |
|
value="insanely detailed ancient greek marble浮雕, 8k cinematic lighting" |
|
) |
|
botao_processar = gr.Button("Gerar", variant="primary") |
|
|
|
with gr.Column(): |
|
saida_sr = gr.Image(label="Super-Resolução", show_label=True) |
|
saida_arte = gr.Image(label="Arte em Relevo", show_label=True) |
|
saida_profundidade = gr.Image(label="Mapa de Profundidade", show_label=True) |
|
|
|
botao_processar.click( |
|
pipeline_completo, |
|
inputs=[entrada_imagem, controle_escala, seletor_modelo, entrada_prompt], |
|
outputs=[saida_sr, saida_arte, saida_profundidade] |
|
) |
|
|
|
if __name__ == "__main__": |
|
app.launch(server_name="0.0.0.0", server_port=7860) |