ds1david commited on
Commit
65579be
·
1 Parent(s): 85a119c

fixing bugs

Browse files
Files changed (2) hide show
  1. app.py +198 -105
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import torch
3
  import numpy as np
@@ -10,113 +11,205 @@ from super_resolve import process
10
  from diffusers import StableDiffusionXLImg2ImgPipeline
11
  from transformers import DPTFeatureExtractor, DPTForDepthEstimation
12
 
13
- # Fix de compatibilidade
 
 
 
 
 
14
  file_download.cached_download = file_download.hf_hub_download
15
 
16
- # ========== Configuração do Thera ==========
17
- REPO_ID_EDSR = "prs-eth/thera-edsr-pro"
18
- REPO_ID_RDN = "prs-eth/thera-rdn-pro"
19
-
20
-
21
- def load_thera_model(repo_id):
22
- model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
23
- with open(model_path, 'rb') as fh:
24
- check = pickle.load(fh)
25
- return build_thera(3, check['backbone'], check['size']), check['model']
26
-
27
-
28
- model_edsr, params_edsr = load_thera_model(REPO_ID_EDSR)
29
- model_rdn, params_rdn = load_thera_model(REPO_ID_RDN)
30
-
31
- # ========== Configuração do SDXL + Depth ==========
32
- device = "cpu"
33
- torch_dtype = torch.float32
34
- # device = "cuda" if torch.cuda.is_available() else "cpu"
35
- # torch_dtype = torch.float16 if device == "cuda" else torch.float32
36
-
37
-
38
- pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
39
- "stabilityai/stable-diffusion-xl-base-1.0",
40
- torch_dtype=torch_dtype
41
- ).to(device)
42
-
43
- pipe.load_lora_weights(
44
- "KappaNeuro/bas-relief",
45
- weight_name="BAS-RELIEF.safetensors",
46
- peft_backend="peft"
47
- )
48
-
49
- # ========== Configuração do Modelo de Profundidade ==========
50
- depth_processor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large") # Nome padronizado
51
- depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
52
-
53
- # ========== Fluxo Integrado ==========
54
- def full_pipeline(image, scale_factor, model_type, style_prompt):
55
- # 1. Super-Resolution (JAX)
56
- sr_model = model_edsr if model_type == "EDSR" else model_rdn
57
- sr_params = params_edsr if model_type == "EDSR" else params_rdn
58
-
59
- # Processar e converter para numpy array
60
- sr_jax = process(np.array(image) / 255., sr_model, sr_params,
61
- (round(image.size[1] * scale_factor),
62
- round(image.size[0] * scale_factor)),
63
- True)
64
-
65
- # Conversão crítica: JAX Array → numpy → PIL
66
- sr_np = np.asarray(sr_jax)
67
- sr_pil = Image.fromarray(sr_np)
68
-
69
- # 2. Style Transfer (PyTorch)
70
- prompt = f"BAS-RELIEF {style_prompt}, intricate carving, marble texture"
71
- bas_relief = pipe(
72
- prompt=prompt,
73
- image=sr_pil, # Usar PIL Image diretamente
74
- strength=0.6,
75
- num_inference_steps=25,
76
- guidance_scale=7.5
77
- ).images[0]
78
-
79
- # 3. Depth Map
80
- inputs = depth_processor(bas_relief, return_tensors="pt").to(device)
81
- with torch.no_grad():
82
- outputs = depth_model(**inputs)
83
- depth = outputs.predicted_depth
84
-
85
- depth = torch.nn.functional.interpolate(
86
- depth.unsqueeze(1),
87
- mode="bicubic",
88
- size=bas_relief.size[::-1]
89
- ).squeeze().cpu().numpy()
90
-
91
- depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
92
- depth = (depth * 255).astype(np.uint8)
93
-
94
- return sr_pil, bas_relief, Image.fromarray(depth)
95
-
96
- # ========== Interface Gradio ==========
97
- with gr.Blocks(title="TheraSR + Bas-Relief Fusion") as app:
98
- gr.Markdown("## 🪄 Super-Resolution → Bas-Relief → Depth Map")
99
-
100
- with gr.Row():
101
- with gr.Column():
102
- input_image = gr.Image(label="Input Image", type="pil")
103
- scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
104
- model_type = gr.Radio(["EDSR", "RDN"], value="EDSR", label="SR Model")
105
- style_prompt = gr.Textbox(
106
- label="Style Prompt",
107
- value="insanely detailed and complex engraving relief, ultra-high definition" # <-- Alteração aqui
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  )
109
- process_btn = gr.Button("Start Pipeline")
110
-
111
- with gr.Column():
112
- sr_output = gr.Image(label="Super-Resolution Result")
113
- style_output = gr.Image(label="Bas-Relief Result")
114
- depth_output = gr.Image(label="Depth Map")
115
-
116
- process_btn.click(
117
- full_pipeline,
118
- inputs=[input_image, scale, model_type, style_prompt],
119
- outputs=[sr_output, style_output, depth_output]
 
 
 
 
 
 
 
 
 
 
 
 
120
  )
121
 
122
- app.launch(debug=False)
 
 
 
 
 
 
 
 
 
1
+ import logging
2
  import gradio as gr
3
  import torch
4
  import numpy as np
 
11
  from diffusers import StableDiffusionXLImg2ImgPipeline
12
  from transformers import DPTFeatureExtractor, DPTForDepthEstimation
13
 
14
+ # ================== CONFIGURAÇÃO INICIAL ==================
15
+ # Configurar sistema de logging
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # Fix para compatibilidade do Hugging Face Hub
20
  file_download.cached_download = file_download.hf_hub_download
21
 
22
+ # ================== CONFIGURAÇÃO DE HARDWARE ==================
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+ torch_dtype = torch.float16 if device == "cuda" else torch.float32
25
+ logger.info(f"Dispositivo selecionado: {device.upper()}")
26
+ logger.info(f"Precisão numérica: {str(torch_dtype).replace('torch.', '')}")
27
+
28
+
29
+ # ================== CARREGAMENTO DE MODELOS ==================
30
+ def carregar_modelo_thera(repo_id):
31
+ """Carrega modelos Thera do Hugging Face Hub"""
32
+ try:
33
+ logger.info(f"Carregando modelo Thera: {repo_id}")
34
+ caminho_modelo = hf_hub_download(repo_id=repo_id, filename="model.pkl")
35
+ with open(caminho_modelo, 'rb') as arquivo:
36
+ dados = pickle.load(arquivo)
37
+ modelo = build_thera(3, dados['backbone'], dados['size'])
38
+ parametros = dados['model']
39
+ logger.success(f"Modelo {repo_id} carregado com sucesso")
40
+ return modelo, parametros
41
+ except Exception as erro:
42
+ logger.error(f"Falha ao carregar {repo_id}: {str(erro)}")
43
+ raise
44
+
45
+
46
+ # Carregar modelos Thera
47
+ try:
48
+ logger.divider("Carregando Modelos Thera")
49
+ modelo_edsr, params_edsr = carregar_modelo_thera("prs-eth/thera-edsr-pro")
50
+ modelo_rdn, params_rdn = carregar_modelo_thera("prs-eth/thera-rdn-pro")
51
+ except Exception as erro:
52
+ logger.critical("Falha crítica no carregamento dos modelos Thera")
53
+ raise
54
+
55
+ # ================== PIPELINE DE ARTE ==================
56
+ # Configurar SDXL + LoRA
57
+ try:
58
+ logger.divider("Configurando Pipeline de Arte")
59
+ pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
60
+ "stabilityai/stable-diffusion-xl-base-1.0",
61
+ torch_dtype=torch_dtype,
62
+ variant="fp16",
63
+ use_safetensors=True
64
+ ).to(device)
65
+
66
+ pipe.load_lora_weights(
67
+ "KappaNeuro/bas-relief",
68
+ weight_name="BAS-RELIEF.safetensors",
69
+ adapter_name="bas_relief"
70
+ )
71
+ logger.success("Pipeline SDXL + LoRA configurado")
72
+ except Exception as erro:
73
+ logger.error(f"Erro no SDXL: {str(erro)}")
74
+ pipe = None
75
+
76
+ # Configurar modelo de profundidade
77
+ try:
78
+ logger.divider("Configurando Modelo de Profundidade")
79
+ processador_profundidade = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
80
+ modelo_profundidade = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
81
+ logger.success("Modelo de profundidade pronto")
82
+ except Exception as erro:
83
+ logger.error(f"Erro no modelo de profundidade: {str(erro)}")
84
+ modelo_profundidade = None
85
+
86
+
87
+ # ================== FLUXO DE PROCESSAMENTO ==================
88
+ def pipeline_completo(imagem, fator_escala, modelo_escolhido, prompt_estilo):
89
+ """Executa todo o fluxo de processamento"""
90
+ try:
91
+ logger.divider("Iniciando novo processamento")
92
+
93
+ # ========= FASE 1: SUPER-RESOLUÇÃO =========
94
+ logger.etapa("Processando Super-Resolução")
95
+ modelo_sr = modelo_edsr if modelo_escolhido == "EDSR" else modelo_rdn
96
+ parametros_sr = params_edsr if modelo_escolhido == "EDSR" else params_rdn
97
+
98
+ # Converter e validar entrada
99
+ if not isinstance(imagem, Image.Image):
100
+ logger.warning("Convertendo entrada numpy para PIL Image")
101
+ imagem = Image.fromarray(imagem)
102
+
103
+ # Processar super-resolução
104
+ imagem_sr_jax = process(
105
+ np.array(imagem) / 255.,
106
+ modelo_sr,
107
+ parametros_sr,
108
+ (round(imagem.size[1] * fator_escala),
109
+ round(imagem.size[0] * fator_escala)),
110
+ True
111
+ )
112
+
113
+ # Converter para formato compatível
114
+ imagem_sr_pil = Image.fromarray(np.array(imagem_sr_jax)).convert("RGB")
115
+ logger.success(f"Super-Resolução concluída: {imagem_sr_pil.size}")
116
+
117
+ # ========= FASE 2: ESTILO BAIXO-RELEVO =========
118
+ if device == "cpu" or not pipe:
119
+ logger.warning("GPU não disponível - Pulando estilo")
120
+ return imagem_sr_pil, None, None
121
+
122
+ logger.etapa("Aplicando Estilo Baixo-Relevo")
123
+ prompt_completo = f"BAS-RELIEF {prompt_estilo}, intricate carving, marble texture, 8k"
124
+
125
+ with torch.autocast(device_type=device.split(':')[0], dtype=torch_dtype):
126
+ imagem_estilizada = pipe(
127
+ prompt=prompt_completo,
128
+ image=imagem_sr_pil,
129
+ strength=0.7,
130
+ num_inference_steps=35,
131
+ guidance_scale=7.5,
132
+ output_type="pil"
133
+ ).images[0]
134
+
135
+ logger.success(f"Estilo aplicado: {imagem_estilizada.size}")
136
+
137
+ # ========= FASE 3: MAPA DE PROFUNDIDADE =========
138
+ logger.etapa("Gerando Mapa de Profundidade")
139
+ inputs = processador_profundidade(
140
+ images=imagem_estilizada,
141
+ return_tensors="pt"
142
+ ).to(device, dtype=torch_dtype)
143
+
144
+ with torch.no_grad(), torch.autocast(device_type=device.split(':')[0]):
145
+ outputs = modelo_profundidade(**inputs)
146
+ profundidade = outputs.predicted_depth
147
+
148
+ # Processar profundidade
149
+ profundidade = torch.nn.functional.interpolate(
150
+ profundidade.unsqueeze(1).float(), # Converter para float32
151
+ size=imagem_estilizada.size[::-1],
152
+ mode="bicubic"
153
+ ).squeeze().cpu().numpy()
154
+
155
+ # Normalizar e converter
156
+ profundidade = (profundidade - profundidade.min()) / (profundidade.max() - profundidade.min() + 1e-8)
157
+ mapa_profundidade = Image.fromarray((profundidade * 255).astype(np.uint8))
158
+
159
+ logger.success("Processamento completo")
160
+ return imagem_sr_pil, imagem_estilizada, mapa_profundidade
161
+
162
+ except Exception as erro:
163
+ logger.error(f"ERRO NO PIPELINE: {str(erro)}", exc_info=True)
164
+ return imagem_sr_pil if 'imagem_sr_pil' in locals() else None, None, None
165
+
166
+
167
+ # ================== INTERFACE GRADIO ==================
168
+ with gr.Blocks(title="TheraSR Art Suite", theme=gr.themes.Soft()) as app:
169
+ gr.Markdown("""
170
+ # 🎨 TheraSR Art Suite
171
+ **Combine super-resolução aliasing-free com geração artística de baixo-relevo**
172
+ """)
173
+
174
+ with gr.Row(variant="panel"):
175
+ with gr.Column(scale=1):
176
+ entrada_imagem = gr.Image(label="🖼 Imagem de Entrada", type="pil")
177
+ seletor_modelo = gr.Radio(
178
+ ["EDSR", "RDN"],
179
+ value="EDSR",
180
+ label="🔧 Modelo de Super-Resolução"
181
  )
182
+ controle_escala = gr.Slider(
183
+ 1.0, 6.0,
184
+ value=2.0,
185
+ step=0.1,
186
+ label="🔍 Fator de Escala"
187
+ )
188
+ entrada_prompt = gr.Textbox(
189
+ label="📝 Prompt de Estilo",
190
+ value="insanely detailed and complex engraving relief, ultra HD 8k",
191
+ placeholder="Descreva o estilo desejado..."
192
+ )
193
+ botao_processar = gr.Button("🚀 Processar Imagem", variant="primary")
194
+
195
+ with gr.Column(scale=2):
196
+ saida_sr = gr.Image(label="✨ Super-Resolução", interactive=False)
197
+ saida_arte = gr.Image(label="🖌 Arte em Baixo-Relevo", interactive=False)
198
+ saida_profundidade = gr.Image(label="🗺 Mapa de Profundidade", interactive=False)
199
+
200
+ # Configurar eventos
201
+ botao_processar.click(
202
+ fn=pipeline_completo,
203
+ inputs=[entrada_imagem, controle_escala, seletor_modelo, entrada_prompt],
204
+ outputs=[saida_sr, saida_arte, saida_profundidade]
205
  )
206
 
207
+ # ================== INICIALIZAÇÃO ==================
208
+ if __name__ == "__main__":
209
+ app.launch(
210
+ server_name="0.0.0.0",
211
+ server_port=7860,
212
+ show_error=True,
213
+ share=False,
214
+ debug=False
215
+ )
requirements.txt CHANGED
@@ -2,6 +2,7 @@
2
 
3
  ConfigArgParse==1.7
4
  Pillow==10.0.0
 
5
  chex==0.1.7
6
  diffusers
7
  einops==0.6.1
@@ -33,4 +34,5 @@ torch
33
  torchvision
34
  tqdm==4.65.0
35
  transformers
36
- wandb
 
 
2
 
3
  ConfigArgParse==1.7
4
  Pillow==10.0.0
5
+ accelerate==0.25.0
6
  chex==0.1.7
7
  diffusers
8
  einops==0.6.1
 
34
  torchvision
35
  tqdm==4.65.0
36
  transformers
37
+ wandb
38
+ xformers==0.0.23