GeorgiosIoannouCoder commited on
Commit
74a334c
1 Parent(s): 8685fa4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +236 -0
app.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ##############################################################################################################
2
+ # Filename: app.py
3
+ # Description: A Streamlit application to test our implementation of the x4 model,
4
+ # as descirbed in the paper "Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data"
5
+ ##############################################################################################################
6
+ #
7
+ # Import libraries.
8
+ #
9
+ import cv2
10
+ import numpy as np
11
+ import requests
12
+ import streamlit as st
13
+
14
+ from basicsr.archs.rrdbnet_arch import RRDBNet
15
+ from inference.real_esrgan import RealEsrGan
16
+ from io import BytesIO
17
+ from PIL import Image
18
+
19
+ ##############################################################################################################
20
+
21
+
22
+ # Function to run inference using the RealEsrGan model.
23
+ def run_inference(
24
+ uploaded_file,
25
+ model_name="REALESRGAN_x4",
26
+ output_path="inferences",
27
+ upscale=4,
28
+ extension="auto",
29
+ device=None,
30
+ gpu_id=None,
31
+ ):
32
+ try:
33
+ # Create an RRDBNet model instance.
34
+ model = RRDBNet(
35
+ num_in_ch=3,
36
+ num_out_ch=3,
37
+ num_feat=64,
38
+ num_block=23,
39
+ num_grow_ch=32,
40
+ scale=upscale,
41
+ )
42
+
43
+ # Set default model path based on the selected model name
44
+ if model_name == None:
45
+ model_path = "./models/REALESRGAN_x4.pth"
46
+ elif model_name == "REALESRGAN_x4":
47
+ model_path = "./models/REALESRGAN_x4.pth"
48
+ elif model_name == "REALESRNET_x4":
49
+ model_path = "./models/REALESRNET_x4.pth"
50
+
51
+ # Create an RealEsrGan model instance.
52
+ upsampler = RealEsrGan(
53
+ scale=upscale,
54
+ model_path=model_path,
55
+ dni_weight=None,
56
+ model=model,
57
+ pre_pad=10,
58
+ half=False,
59
+ device=device,
60
+ gpu_id=gpu_id,
61
+ )
62
+
63
+ # Process the input image.
64
+ if hasattr(
65
+ uploaded_file, "read"
66
+ ): # Check if it's a file uploaded from the local system.
67
+ img_pil = Image.open(uploaded_file)
68
+ elif uploaded_file.startswith("http"): # If it is an image URL.
69
+ response = requests.get(uploaded_file)
70
+ img_pil = Image.open(BytesIO(response.content))
71
+ else:
72
+ st.warning(
73
+ "Invalid input. Please provide either an image file or an image URL."
74
+ )
75
+ return
76
+
77
+ # Convert PIL image to OpenCV format.
78
+ img = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
79
+ # Perform super-resolution using Real-ESRGAN.
80
+ output, _ = upsampler.enhance(img, upscale=upscale)
81
+
82
+ # Determine the file extension for saving the output image.
83
+ if len(img.shape) == 3 and img.shape[2] == 4:
84
+ img_mode = "RGBA"
85
+ extension = "png"
86
+ else:
87
+ img_mode = None
88
+ if extension == "auto":
89
+ extension = "png" # Default extension for images from URL.
90
+
91
+ # Save the super resolution image
92
+ save_path = f"{output_path}/{model_name}_inference.{extension}"
93
+ cv2.imwrite(save_path, output)
94
+ except Exception as e:
95
+ st.error(e)
96
+ return save_path
97
+
98
+
99
+ ##############################################################################################################
100
+
101
+
102
+ # Function to apply local CSS.
103
+ def local_css(file_name):
104
+ with open(file_name) as f:
105
+ st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
106
+
107
+
108
+ ##############################################################################################################
109
+ # Main function to create the Streamlit web application.
110
+ def main():
111
+ try:
112
+ # Load CSS.
113
+ local_css("styles/style.css")
114
+
115
+ # Title.
116
+ title = f"""<p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 2.3rem;">
117
+ Super Upscale Resolution with Real-ESRGAN</p>"""
118
+ st.markdown(title, unsafe_allow_html=True)
119
+
120
+ # Toggle button for displaying text input or file uploader.
121
+ title = f"""<p style="font-family: monospace; color: white;">
122
+ Enter Image URL or Upload Image (checkbox):</p>"""
123
+ st.markdown(title, unsafe_allow_html=True)
124
+
125
+ use_image_url = st.checkbox(
126
+ label="Enter Image URL or Upload Image:", label_visibility="collapsed"
127
+ )
128
+
129
+ # Input for image URL or file uploader based on the checkbox state.
130
+ if use_image_url:
131
+ image_url_label = f"""
132
+ <p style="font-family: monospace; color: white;">Enter Image URL:</p>"""
133
+ st.markdown(image_url_label, unsafe_allow_html=True)
134
+
135
+ image_url = st.text_input(
136
+ label="Enter Image URL:",
137
+ value="",
138
+ label_visibility="collapsed",
139
+ )
140
+ else:
141
+ uploaded_file_label = f"""
142
+ <p style="font-family: monospace; color: white;">Upload Image:</p>"""
143
+ st.markdown(uploaded_file_label, unsafe_allow_html=True)
144
+ uploaded_file = st.file_uploader(
145
+ label="Upload Image:",
146
+ type=["jpg", "png", "jpeg"],
147
+ label_visibility="collapsed",
148
+ )
149
+
150
+ # Dropdown menu for model selection.
151
+ model_name_label = f"""
152
+ <p style="font-family: monospace; color: white;">Select Model:</p>"""
153
+ st.markdown(model_name_label, unsafe_allow_html=True)
154
+
155
+ model_name = st.selectbox(
156
+ label="Select Model:",
157
+ options=[
158
+ "REALESRGAN_x4",
159
+ "REALESRNET_x4",
160
+ ],
161
+ label_visibility="collapsed",
162
+ )
163
+
164
+ # Slider for upscale selection.
165
+ model_name_label = f"""
166
+ <p style="font-family: monospace; color: white;">Select Upscale Factor. Model works best with x4 upscale:</p>"""
167
+ st.markdown(model_name_label, unsafe_allow_html=True)
168
+
169
+ upscale = st.slider(
170
+ label="Select Upscale Factor. Model works best with x4 upscale:",
171
+ min_value=3,
172
+ max_value=10,
173
+ value=4,
174
+ step=1,
175
+ label_visibility="collapsed",
176
+ )
177
+
178
+ if not use_image_url and uploaded_file is not None:
179
+ # Image caption.
180
+ image_caption = f"""<p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 2.3rem;">
181
+ Uploaded Image:</p>"""
182
+ st.markdown(image_caption, unsafe_allow_html=True)
183
+ st.image(uploaded_file)
184
+
185
+ with st.spinner(
186
+ text="Running Inference. May take up to 3 minutes. Please be patient..."
187
+ ):
188
+ if st.button("Run Inference"):
189
+ if use_image_url and image_url != "":
190
+ result_path = run_inference(
191
+ uploaded_file=image_url,
192
+ model_name=model_name,
193
+ upscale=upscale,
194
+ )
195
+ # Image caption.
196
+ image_caption = f"""<p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 2.3rem;">
197
+ Resulting Image:</p>"""
198
+ st.markdown(image_caption, unsafe_allow_html=True)
199
+ st.image(result_path)
200
+
201
+ st.success("Inference completed!")
202
+ elif not use_image_url and uploaded_file is not None:
203
+ result_path = run_inference(
204
+ uploaded_file=uploaded_file,
205
+ model_name=model_name,
206
+ upscale=upscale,
207
+ )
208
+
209
+ # Image caption.
210
+ image_caption = f"""<p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 2.3rem;">
211
+ Resulting Image:</p>"""
212
+ st.markdown(image_caption, unsafe_allow_html=True)
213
+ st.image(result_path)
214
+
215
+ st.success("Inference completed!")
216
+ else:
217
+ st.warning("Please provide either an image file or an image URL.")
218
+
219
+ # GitHub repository of this project.
220
+ st.markdown(
221
+ f"""
222
+ <p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;">
223
+ <b>Check out our <a href="https://github.com/GeorgiosIoannouCoder/realesrgan" style="color: #FAF9F6;">GitHub repository</a></b>
224
+ </p>
225
+ """,
226
+ unsafe_allow_html=True,
227
+ )
228
+ except Exception as e:
229
+ st.error(e)
230
+
231
+
232
+ ##############################################################################################################
233
+
234
+ if __name__ == "__main__":
235
+ main()
236
+ ##############################################################################################################