ds1david commited on
Commit
f41a4a7
·
1 Parent(s): 756dcdd

fixing bugs

Browse files
Files changed (1) hide show
  1. app.py +55 -72
app.py CHANGED
@@ -4,25 +4,25 @@ import numpy as np
4
  import jax
5
  import pickle
6
  from PIL import Image
7
- from huggingface_hub import hf_hub_download
8
  from model import build_thera
9
  from super_resolve import process
10
- from diffusers import StableDiffusionXLPipeline
11
  from transformers import DPTFeatureExtractor, DPTForDepthEstimation
12
 
 
 
 
13
  # ========== Configuração do Thera ==========
14
  REPO_ID_EDSR = "prs-eth/thera-edsr-pro"
15
  REPO_ID_RDN = "prs-eth/thera-rdn-pro"
16
 
17
 
18
- # Carregar modelos Thera
19
  def load_thera_model(repo_id):
20
  model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
21
  with open(model_path, 'rb') as fh:
22
  check = pickle.load(fh)
23
- params, backbone, size = check['model'], check['backbone'], check['size']
24
- model = build_thera(3, backbone, size)
25
- return model, params
26
 
27
 
28
  model_edsr, params_edsr = load_thera_model(REPO_ID_EDSR)
@@ -32,8 +32,7 @@ model_rdn, params_rdn = load_thera_model(REPO_ID_RDN)
32
  device = "cuda" if torch.cuda.is_available() else "cpu"
33
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
34
 
35
- # Carregar modelos de geração
36
- pipe = StableDiffusionXLPipeline.from_pretrained(
37
  "stabilityai/stable-diffusion-xl-base-1.0",
38
  torch_dtype=torch_dtype
39
  ).to(device)
@@ -48,82 +47,66 @@ feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
48
  depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
49
 
50
 
51
- # ========== Funções Principais ==========
52
- def super_resolution(image, scale_factor, model_type):
53
- model = model_edsr if model_type == "EDSR" else model_rdn
54
- params = params_edsr if model_type == "EDSR" else params_rdn
55
-
56
- source = np.asarray(image) / 255.
57
- target_shape = (
58
- round(source.shape[0] * scale_factor),
59
- round(source.shape[1] * scale_factor),
60
- )
61
-
62
- output = process(source, model, params, target_shape, do_ensemble=True)
63
- return Image.fromarray(np.asarray(output))
64
-
65
-
66
- def generate_bas_relief(prompt):
67
- full_prompt = f"BAS-RELIEF {prompt}"
68
- image = pipe(
69
- prompt=full_prompt,
70
  num_inference_steps=25,
71
- guidance_scale=7.5,
72
- height=512,
73
- width=512
74
  ).images[0]
75
 
76
- inputs = feature_extractor(image, return_tensors="pt").to(device)
 
77
  with torch.no_grad():
78
  outputs = depth_model(**inputs)
79
- depth_map = outputs.predicted_depth
80
 
81
- depth_map = torch.nn.functional.interpolate(
82
- depth_map.unsqueeze(1),
83
- size=image.size[::-1],
84
  mode="bicubic"
85
  ).squeeze().cpu().numpy()
86
 
87
- depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
88
- depth_map = (depth_map * 255).astype(np.uint8)
89
 
90
- return image, Image.fromarray(depth_map)
91
 
92
 
93
  # ========== Interface Gradio ==========
94
- with gr.Blocks(title="TheraSR + Bas-Relief Generator") as app:
95
- gr.Markdown("# 🔥 TheraSR + Bas-Relief Generator")
96
- gr.Markdown("Combine aliasing-free super-resolution with artistic bas-relief generation")
97
-
98
- with gr.Tabs():
99
- with gr.TabItem("🖼 Super-Resolution"):
100
- with gr.Row():
101
- sr_input = gr.Image(label="Input Image", type="pil")
102
- sr_output = gr.Image(label="Super-Resolution Result")
103
- sr_scale = gr.Slider(1.0, 6.0, value=2.0, label="Scale Factor")
104
- sr_model = gr.Radio(["EDSR", "RDN"], value="EDSR", label="Model Type")
105
- sr_btn = gr.Button("Enhance Resolution")
106
-
107
- with gr.TabItem("🎨 Generate Bas-Relief"):
108
- with gr.Row():
109
- text_input = gr.Textbox(label="Art Prompt", placeholder="Roman soldier marble relief...")
110
- with gr.Row():
111
- gen_output = gr.Image(label="Generated Art")
112
- depth_output = gr.Image(label="Depth Map")
113
- gen_btn = gr.Button("Generate Artwork")
114
-
115
- # Event Handlers
116
- sr_btn.click(
117
- super_resolution,
118
- inputs=[sr_input, sr_scale, sr_model],
119
- outputs=sr_output
120
- )
121
-
122
- gen_btn.click(
123
- generate_bas_relief,
124
- inputs=text_input,
125
- outputs=[gen_output, depth_output]
126
  )
127
 
128
- # Configuração do Hugging Face
129
- app.launch(debug=False, share=True)
 
4
  import jax
5
  import pickle
6
  from PIL import Image
7
+ from huggingface_hub import hf_hub_download, file_download
8
  from model import build_thera
9
  from super_resolve import process
10
+ from diffusers import StableDiffusionXLImg2ImgPipeline
11
  from transformers import DPTFeatureExtractor, DPTForDepthEstimation
12
 
13
+ # Fix de compatibilidade
14
+ file_download.cached_download = file_download.hf_hub_download
15
+
16
  # ========== Configuração do Thera ==========
17
  REPO_ID_EDSR = "prs-eth/thera-edsr-pro"
18
  REPO_ID_RDN = "prs-eth/thera-rdn-pro"
19
 
20
 
 
21
  def load_thera_model(repo_id):
22
  model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
23
  with open(model_path, 'rb') as fh:
24
  check = pickle.load(fh)
25
+ return build_thera(3, check['backbone'], check['size']), check['model']
 
 
26
 
27
 
28
  model_edsr, params_edsr = load_thera_model(REPO_ID_EDSR)
 
32
  device = "cuda" if torch.cuda.is_available() else "cpu"
33
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
34
 
35
+ pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
 
36
  "stabilityai/stable-diffusion-xl-base-1.0",
37
  torch_dtype=torch_dtype
38
  ).to(device)
 
47
  depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
48
 
49
 
50
+ # ========== Fluxo Integrado ==========
51
+ def full_pipeline(image, scale_factor, model_type, style_prompt):
52
+ # 1. Super-Resolution
53
+ sr_model = model_edsr if model_type == "EDSR" else model_rdn
54
+ sr_params = params_edsr if model_type == "EDSR" else params_rdn
55
+ sr_image = process(np.array(image) / 255., sr_model, sr_params,
56
+ (round(image.size[1] * scale_factor),
57
+ round(image.size[0] * scale_factor)),
58
+ True)
59
+
60
+ # 2. Bas-Relief Style Transfer
61
+ prompt = f"BAS-RELIEF {style_prompt}, intricate carving, marble texture"
62
+ bas_relief = pipe(
63
+ prompt=prompt,
64
+ image=sr_image,
65
+ strength=0.6,
 
 
 
66
  num_inference_steps=25,
67
+ guidance_scale=7.5
 
 
68
  ).images[0]
69
 
70
+ # 3. Depth Map Estimation
71
+ inputs = feature_extractor(bas_relief, return_tensors="pt").to(device)
72
  with torch.no_grad():
73
  outputs = depth_model(**inputs)
74
+ depth = outputs.predicted_depth
75
 
76
+ depth = torch.nn.functional.interpolate(
77
+ depth.unsqueeze(1),
78
+ size=bas_relief.size[::-1],
79
  mode="bicubic"
80
  ).squeeze().cpu().numpy()
81
 
82
+ depth = (depth - depth.min()) / (depth.max() - depth.min())
83
+ depth = (depth * 255).astype(np.uint8)
84
 
85
+ return sr_image, bas_relief, Image.fromarray(depth)
86
 
87
 
88
  # ========== Interface Gradio ==========
89
+ with gr.Blocks(title="TheraSR + Bas-Relief Fusion") as app:
90
+ gr.Markdown("## 🪄 Super-Resolution Bas-Relief → Depth Map")
91
+
92
+ with gr.Row():
93
+ with gr.Column():
94
+ input_image = gr.Image(label="Input Image", type="pil")
95
+ scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
96
+ model_type = gr.Radio(["EDSR", "RDN"], value="EDSR", label="SR Model")
97
+ style_prompt = gr.Textbox(label="Style Prompt",
98
+ placeholder="marble sculpture, ancient greek style")
99
+ process_btn = gr.Button("Start Pipeline")
100
+
101
+ with gr.Column():
102
+ sr_output = gr.Image(label="Super-Resolution Result")
103
+ style_output = gr.Image(label="Bas-Relief Result")
104
+ depth_output = gr.Image(label="Depth Map")
105
+
106
+ process_btn.click(
107
+ full_pipeline,
108
+ inputs=[input_image, scale, model_type, style_prompt],
109
+ outputs=[sr_output, style_output, depth_output]
 
 
 
 
 
 
 
 
 
 
 
110
  )
111
 
112
+ app.launch(debug=False)