ds1david commited on
Commit
357df1b
·
1 Parent(s): 65579be

fixing bugs

Browse files
Files changed (1) hide show
  1. app.py +155 -150
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import logging
2
  import gradio as gr
3
  import torch
@@ -5,211 +6,215 @@ import numpy as np
5
  import jax
6
  import pickle
7
  from PIL import Image
8
- from huggingface_hub import hf_hub_download, file_download
9
  from model import build_thera
10
  from super_resolve import process
11
  from diffusers import StableDiffusionXLImg2ImgPipeline
12
  from transformers import DPTFeatureExtractor, DPTForDepthEstimation
13
 
14
- # ================== CONFIGURAÇÃO INICIAL ==================
15
- # Configurar sistema de logging
16
- logging.basicConfig(level=logging.INFO)
17
- logger = logging.getLogger(__name__)
18
 
19
- # Fix para compatibilidade do Hugging Face Hub
20
- file_download.cached_download = file_download.hf_hub_download
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # ================== CONFIGURAÇÃO DE HARDWARE ==================
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
24
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
25
- logger.info(f"Dispositivo selecionado: {device.upper()}")
26
- logger.info(f"Precisão numérica: {str(torch_dtype).replace('torch.', '')}")
 
27
 
28
 
29
  # ================== CARREGAMENTO DE MODELOS ==================
30
  def carregar_modelo_thera(repo_id):
31
- """Carrega modelos Thera do Hugging Face Hub"""
32
  try:
33
- logger.info(f"Carregando modelo Thera: {repo_id}")
34
- caminho_modelo = hf_hub_download(repo_id=repo_id, filename="model.pkl")
35
- with open(caminho_modelo, 'rb') as arquivo:
36
- dados = pickle.load(arquivo)
37
- modelo = build_thera(3, dados['backbone'], dados['size'])
38
- parametros = dados['model']
39
- logger.success(f"Modelo {repo_id} carregado com sucesso")
40
- return modelo, parametros
41
- except Exception as erro:
42
- logger.error(f"Falha ao carregar {repo_id}: {str(erro)}")
43
- raise
44
 
45
 
46
  # Carregar modelos Thera
47
  try:
48
- logger.divider("Carregando Modelos Thera")
49
  modelo_edsr, params_edsr = carregar_modelo_thera("prs-eth/thera-edsr-pro")
50
  modelo_rdn, params_rdn = carregar_modelo_thera("prs-eth/thera-rdn-pro")
51
- except Exception as erro:
52
- logger.critical("Falha crítica no carregamento dos modelos Thera")
53
  raise
54
 
55
  # ================== PIPELINE DE ARTE ==================
56
- # Configurar SDXL + LoRA
57
- try:
58
- logger.divider("Configurando Pipeline de Arte")
59
- pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
60
- "stabilityai/stable-diffusion-xl-base-1.0",
61
- torch_dtype=torch_dtype,
62
- variant="fp16",
63
- use_safetensors=True
64
- ).to(device)
65
-
66
- pipe.load_lora_weights(
67
- "KappaNeuro/bas-relief",
68
- weight_name="BAS-RELIEF.safetensors",
69
- adapter_name="bas_relief"
70
- )
71
- logger.success("Pipeline SDXL + LoRA configurado")
72
- except Exception as erro:
73
- logger.error(f"Erro no SDXL: {str(erro)}")
74
- pipe = None
75
 
76
- # Configurar modelo de profundidade
77
- try:
78
- logger.divider("Configurando Modelo de Profundidade")
79
- processador_profundidade = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
80
- modelo_profundidade = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
81
- logger.success("Modelo de profundidade pronto")
82
- except Exception as erro:
83
- logger.error(f"Erro no modelo de profundidade: {str(erro)}")
84
- modelo_profundidade = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
 
87
  # ================== FLUXO DE PROCESSAMENTO ==================
88
  def pipeline_completo(imagem, fator_escala, modelo_escolhido, prompt_estilo):
89
- """Executa todo o fluxo de processamento"""
90
  try:
91
- logger.divider("Iniciando novo processamento")
92
 
93
- # ========= FASE 1: SUPER-RESOLUÇÃO =========
94
- logger.etapa("Processando Super-Resolução")
95
- modelo_sr = modelo_edsr if modelo_escolhido == "EDSR" else modelo_rdn
96
- parametros_sr = params_edsr if modelo_escolhido == "EDSR" else params_rdn
97
-
98
- # Converter e validar entrada
99
  if not isinstance(imagem, Image.Image):
100
- logger.warning("Convertendo entrada numpy para PIL Image")
101
  imagem = Image.fromarray(imagem)
102
 
103
- # Processar super-resolução
104
- imagem_sr_jax = process(
 
 
 
 
105
  np.array(imagem) / 255.,
106
- modelo_sr,
107
- parametros_sr,
108
  (round(imagem.size[1] * fator_escala),
109
  round(imagem.size[0] * fator_escala)),
110
  True
111
  )
112
 
113
- # Converter para formato compatível
114
- imagem_sr_pil = Image.fromarray(np.array(imagem_sr_jax)).convert("RGB")
115
- logger.success(f"Super-Resolução concluída: {imagem_sr_pil.size}")
116
-
117
- # ========= FASE 2: ESTILO BAIXO-RELEVO =========
118
- if device == "cpu" or not pipe:
119
- logger.warning("GPU não disponível - Pulando estilo")
120
- return imagem_sr_pil, None, None
121
-
122
- logger.etapa("Aplicando Estilo Baixo-Relevo")
123
- prompt_completo = f"BAS-RELIEF {prompt_estilo}, intricate carving, marble texture, 8k"
124
-
125
- with torch.autocast(device_type=device.split(':')[0], dtype=torch_dtype):
126
- imagem_estilizada = pipe(
127
- prompt=prompt_completo,
128
- image=imagem_sr_pil,
129
- strength=0.7,
130
- num_inference_steps=35,
131
- guidance_scale=7.5,
132
- output_type="pil"
133
- ).images[0]
134
-
135
- logger.success(f"Estilo aplicado: {imagem_estilizada.size}")
136
-
137
- # ========= FASE 3: MAPA DE PROFUNDIDADE =========
138
- logger.etapa("Gerando Mapa de Profundidade")
139
- inputs = processador_profundidade(
140
- images=imagem_estilizada,
141
- return_tensors="pt"
142
- ).to(device, dtype=torch_dtype)
143
-
144
- with torch.no_grad(), torch.autocast(device_type=device.split(':')[0]):
145
- outputs = modelo_profundidade(**inputs)
146
- profundidade = outputs.predicted_depth
147
-
148
- # Processar profundidade
149
- profundidade = torch.nn.functional.interpolate(
150
- profundidade.unsqueeze(1).float(), # Converter para float32
151
- size=imagem_estilizada.size[::-1],
152
- mode="bicubic"
153
- ).squeeze().cpu().numpy()
154
-
155
- # Normalizar e converter
156
- profundidade = (profundidade - profundidade.min()) / (profundidade.max() - profundidade.min() + 1e-8)
157
- mapa_profundidade = Image.fromarray((profundidade * 255).astype(np.uint8))
158
-
159
- logger.success("Processamento completo")
160
- return imagem_sr_pil, imagem_estilizada, mapa_profundidade
161
-
162
- except Exception as erro:
163
- logger.error(f"ERRO NO PIPELINE: {str(erro)}", exc_info=True)
164
- return imagem_sr_pil if 'imagem_sr_pil' in locals() else None, None, None
165
 
166
 
167
  # ================== INTERFACE GRADIO ==================
168
  with gr.Blocks(title="TheraSR Art Suite", theme=gr.themes.Soft()) as app:
169
- gr.Markdown("""
170
- # 🎨 TheraSR Art Suite
171
- **Combine super-resolução aliasing-free com geração artística de baixo-relevo**
172
- """)
173
-
174
- with gr.Row(variant="panel"):
175
- with gr.Column(scale=1):
176
- entrada_imagem = gr.Image(label="🖼 Imagem de Entrada", type="pil")
177
  seletor_modelo = gr.Radio(
178
  ["EDSR", "RDN"],
179
  value="EDSR",
180
- label="🔧 Modelo de Super-Resolução"
181
- )
182
- controle_escala = gr.Slider(
183
- 1.0, 6.0,
184
- value=2.0,
185
- step=0.1,
186
- label="🔍 Fator de Escala"
187
  )
 
188
  entrada_prompt = gr.Textbox(
189
- label="📝 Prompt de Estilo",
190
- value="insanely detailed and complex engraving relief, ultra HD 8k",
191
- placeholder="Descreva o estilo desejado..."
192
  )
193
- botao_processar = gr.Button("🚀 Processar Imagem", variant="primary")
194
 
195
- with gr.Column(scale=2):
196
- saida_sr = gr.Image(label="Super-Resolução", interactive=False)
197
- saida_arte = gr.Image(label="🖌 Arte em Baixo-Relevo", interactive=False)
198
- saida_profundidade = gr.Image(label="🗺 Mapa de Profundidade", interactive=False)
199
 
200
- # Configurar eventos
201
  botao_processar.click(
202
- fn=pipeline_completo,
203
  inputs=[entrada_imagem, controle_escala, seletor_modelo, entrada_prompt],
204
  outputs=[saida_sr, saida_arte, saida_profundidade]
205
  )
206
 
207
- # ================== INICIALIZAÇÃO ==================
208
  if __name__ == "__main__":
209
- app.launch(
210
- server_name="0.0.0.0",
211
- server_port=7860,
212
- show_error=True,
213
- share=False,
214
- debug=False
215
- )
 
1
+ # app.py
2
  import logging
3
  import gradio as gr
4
  import torch
 
6
  import jax
7
  import pickle
8
  from PIL import Image
9
+ from huggingface_hub import hf_hub_download
10
  from model import build_thera
11
  from super_resolve import process
12
  from diffusers import StableDiffusionXLImg2ImgPipeline
13
  from transformers import DPTFeatureExtractor, DPTForDepthEstimation
14
 
 
 
 
 
15
 
16
+ # ================== CONFIGURAÇÃO DE LOGGING ==================
17
+ class CustomLogger:
18
+ def __init__(self, name):
19
+ self.logger = logging.getLogger(name)
20
+ formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
21
+ handler = logging.StreamHandler()
22
+ handler.setFormatter(formatter)
23
+ self.logger.addHandler(handler)
24
+ self.logger.setLevel(logging.INFO)
25
+
26
+
27
+ def divider(self, text=None, length=60):
28
+ if text:
29
+ # Cálculo seguro do número de '='
30
+ available_space = length - len(text) - 12 # 10 '=' + 2 espaços
31
+ if available_space < 1:
32
+ available_space = 1 # Garante pelo menos 1 '='
33
+ msg = f"\n{'=' * 10} {text.upper()} {'=' * available_space}"
34
+ else:
35
+ msg = "\n" + "=" * length
36
+ self.logger.info(msg)
37
+
38
+
39
+ def etapa(self, text):
40
+ self.logger.info(f"▶ {text}")
41
+
42
+ def success(self, text):
43
+ self.logger.info(f"✓ {text}")
44
+
45
+ def error(self, text):
46
+ self.logger.error(f"✗ {text}")
47
+
48
+ def warning(self, text):
49
+ self.logger.warning(f"⚠ {text}")
50
+
51
+
52
+ logger = CustomLogger(__name__)
53
 
54
  # ================== CONFIGURAÇÃO DE HARDWARE ==================
55
  device = "cuda" if torch.cuda.is_available() else "cpu"
56
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
57
+ logger.divider("Configuração Inicial")
58
+ logger.success(f"Dispositivo detectado: {device.upper()}")
59
+ logger.success(f"Precisão numérica: {str(torch_dtype).replace('torch.', '')}")
60
 
61
 
62
  # ================== CARREGAMENTO DE MODELOS ==================
63
  def carregar_modelo_thera(repo_id):
 
64
  try:
65
+ logger.divider(f"Carregando Modelo: {repo_id}")
66
+ model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
67
+ with open(model_path, 'rb') as f:
68
+ check = pickle.load(f)
69
+ model = build_thera(3, check['backbone'], check['size'])
70
+ params = check['model']
71
+ logger.success(f"Modelo {repo_id} carregado")
72
+ return model, params
73
+ except Exception as e:
74
+ logger.error(f"Falha ao carregar {repo_id}: {str(e)}")
75
+ return None, None
76
 
77
 
78
  # Carregar modelos Thera
79
  try:
 
80
  modelo_edsr, params_edsr = carregar_modelo_thera("prs-eth/thera-edsr-pro")
81
  modelo_rdn, params_rdn = carregar_modelo_thera("prs-eth/thera-rdn-pro")
82
+ except Exception as e:
83
+ logger.error("Falha crítica no carregamento dos modelos Thera")
84
  raise
85
 
86
  # ================== PIPELINE DE ARTE ==================
87
+ pipe = None
88
+ modelo_profundidade = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ if device == "cuda":
91
+ try:
92
+ logger.divider("Configurando Pipeline de Arte")
93
+ pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
94
+ "stabilityai/stable-diffusion-xl-base-1.0",
95
+ torch_dtype=torch_dtype,
96
+ variant="fp16",
97
+ use_safetensors=True
98
+ ).to(device)
99
+
100
+ pipe.load_lora_weights(
101
+ "KappaNeuro/bas-relief",
102
+ weight_name="BAS-RELIEF.safetensors"
103
+ )
104
+ logger.success("Pipeline SDXL configurado")
105
+
106
+ logger.etapa("Configurando Modelo de Profundidade")
107
+ processador_profundidade = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
108
+ modelo_profundidade = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
109
+ logger.success("Modelo de profundidade pronto")
110
+
111
+ except Exception as e:
112
+ logger.error(f"Erro na configuração da GPU: {str(e)}")
113
+ pipe = None
114
+ modelo_profundidade = None
115
 
116
 
117
  # ================== FLUXO DE PROCESSAMENTO ==================
118
  def pipeline_completo(imagem, fator_escala, modelo_escolhido, prompt_estilo):
 
119
  try:
120
+ logger.divider("Novo Processamento")
121
 
122
+ # Converter entrada
 
 
 
 
 
123
  if not isinstance(imagem, Image.Image):
 
124
  imagem = Image.fromarray(imagem)
125
 
126
+ # ========= SUPER-RESOLUÇÃO =========
127
+ logger.etapa("Processando Super-Resolução")
128
+ modelo = modelo_edsr if modelo_escolhido == "EDSR" else modelo_rdn
129
+ params = params_edsr if modelo_escolhido == "EDSR" else params_rdn
130
+
131
+ sr_array = process(
132
  np.array(imagem) / 255.,
133
+ modelo,
134
+ params,
135
  (round(imagem.size[1] * fator_escala),
136
  round(imagem.size[0] * fator_escala)),
137
  True
138
  )
139
 
140
+ sr_pil = Image.fromarray(np.array(sr_array)).convert("RGB")
141
+ logger.success(f"Super-Resolução: {sr_pil.size[0]}x{sr_pil.size[1]}")
142
+
143
+ # ========= ESTILO BAIXO-RELEVO =========
144
+ arte_pil = None
145
+ if pipe and modelo_profundidade:
146
+ try:
147
+ logger.etapa("Aplicando Estilo Artístico")
148
+ resultado = pipe(
149
+ prompt=f"BAS-RELIEF {prompt_estilo}, intricate marble carving, 8k ultra HD",
150
+ image=sr_pil,
151
+ strength=0.65,
152
+ num_inference_steps=30,
153
+ guidance_scale=7.5
154
+ )
155
+ arte_pil = resultado.images[0]
156
+ logger.success(f"Arte gerada: {arte_pil.size[0]}x{arte_pil.size[1]}")
157
+ except Exception as e:
158
+ logger.error(f"Falha no estilo: {str(e)}")
159
+
160
+ # ========= MAPA DE PROFUNDIDADE =========
161
+ mapa_pil = None
162
+ if arte_pil and modelo_profundidade:
163
+ try:
164
+ logger.etapa("Calculando Profundidade")
165
+ inputs = processador_profundidade(images=arte_pil, return_tensors="pt").to(device)
166
+ with torch.no_grad():
167
+ outputs = modelo_profundidade(**inputs)
168
+ depth = outputs.predicted_depth
169
+
170
+ depth = torch.nn.functional.interpolate(
171
+ depth.unsqueeze(1).float(),
172
+ size=arte_pil.size[::-1],
173
+ mode="bicubic"
174
+ ).squeeze().cpu().numpy()
175
+
176
+ depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
177
+ mapa_pil = Image.fromarray((depth * 255).astype(np.uint8))
178
+ logger.success("Mapa de profundidade calculado")
179
+ except Exception as e:
180
+ logger.error(f"Falha na profundidade: {str(e)}")
181
+
182
+ return sr_pil, arte_pil or sr_pil, mapa_pil or sr_pil
183
+
184
+ except Exception as e:
185
+ logger.error(f"Erro no pipeline: {str(e)}")
186
+ return None, None, None
 
 
 
 
 
187
 
188
 
189
  # ================== INTERFACE GRADIO ==================
190
  with gr.Blocks(title="TheraSR Art Suite", theme=gr.themes.Soft()) as app:
191
+ gr.Markdown("# 🎨 TheraSR - Super Resolução & Arte Generativa")
192
+
193
+ with gr.Row():
194
+ with gr.Column():
195
+ entrada_imagem = gr.Image(label="Imagem de Entrada", type="pil")
 
 
 
196
  seletor_modelo = gr.Radio(
197
  ["EDSR", "RDN"],
198
  value="EDSR",
199
+ label="Modelo de Super-Resolução"
 
 
 
 
 
 
200
  )
201
+ controle_escala = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
202
  entrada_prompt = gr.Textbox(
203
+ label="Prompt de Estilo",
204
+ value="insanely detailed ancient greek marble浮雕, 8k cinematic lighting"
 
205
  )
206
+ botao_processar = gr.Button("Gerar", variant="primary")
207
 
208
+ with gr.Column():
209
+ saida_sr = gr.Image(label="Super-Resolução", show_label=True)
210
+ saida_arte = gr.Image(label="Arte em Relevo", show_label=True)
211
+ saida_profundidade = gr.Image(label="Mapa de Profundidade", show_label=True)
212
 
 
213
  botao_processar.click(
214
+ pipeline_completo,
215
  inputs=[entrada_imagem, controle_escala, seletor_modelo, entrada_prompt],
216
  outputs=[saida_sr, saida_arte, saida_profundidade]
217
  )
218
 
 
219
  if __name__ == "__main__":
220
+ app.launch(server_name="0.0.0.0", server_port=7860)