mednow commited on
Commit
4bb3f85
·
verified ·
1 Parent(s): 1136c1e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -41
app.py CHANGED
@@ -18,32 +18,28 @@ def load_model(scale, anime=False):
18
  model.load_weights(model_path)
19
  return model
20
  except Exception as e:
21
- st.error(f"Error loading the model: {e}")
22
  return None
23
 
24
  def enhance_image(image, scale, anime):
25
- model = load_model(scale, anime=anime)
26
- if model is None:
27
- return None, None
28
-
29
  try:
 
 
 
 
30
  # Convert image to RGB if it has an alpha channel
31
  if image.mode != 'RGB':
32
  image = image.convert('RGB')
33
 
34
- # Process the image with the model
35
  sr_image = model.predict(image)
36
 
37
- # Ensure the enhanced image has the same size as the original
38
- sr_image = sr_image.resize(image.size)
39
-
40
- # Save enhanced image to buffer
41
  buffer = BytesIO()
42
  sr_image.save(buffer, format="PNG")
43
  buffer.seek(0)
44
  return sr_image, buffer
 
45
  except Exception as e:
46
- st.error(f"Error enhancing the image: {e}")
47
  return None, None
48
 
49
  def main():
@@ -53,38 +49,41 @@ def main():
53
  uploaded_image = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
54
 
55
  if uploaded_image is not None:
56
- image = Image.open(uploaded_image)
57
-
58
- # Anime toggle
59
- anime = st.checkbox("Anime Image", value=False)
60
-
61
- # Conditional scale options
62
- if anime:
63
- scale = "4x" # Set to 4x automatically when anime is selected
64
- else:
65
- scale = st.radio("Upscaling Factor", ["2x", "4x", "8x"], index=0)
66
-
67
- scale_value = int(scale.replace('x', ''))
68
-
69
- # Enhance button
70
- if st.button("Restore Image"):
71
- enhanced_image, buffer = enhance_image(image, scale_value, anime)
72
 
73
- if enhanced_image:
74
- # Show images side by side
75
- col1, col2 = st.columns(2)
76
- with col1:
77
- st.image(image, caption="Original Image", use_column_width=True)
78
- with col2:
79
- st.image(enhanced_image, caption="Enhanced Image", use_column_width=True)
80
 
81
- # Download button
82
- st.download_button(
83
- label="Download Enhanced Image",
84
- data=buffer,
85
- file_name="enhanced_image.png",
86
- mime="image/png"
87
- )
 
 
 
 
 
 
 
 
 
 
88
 
89
  if __name__ == "__main__":
90
  main()
 
18
  model.load_weights(model_path)
19
  return model
20
  except Exception as e:
21
+ st.error(f"Failed to load the model: {e}")
22
  return None
23
 
24
  def enhance_image(image, scale, anime):
 
 
 
 
25
  try:
26
+ model = load_model(scale, anime=anime)
27
+ if model is None:
28
+ return None, None
29
+
30
  # Convert image to RGB if it has an alpha channel
31
  if image.mode != 'RGB':
32
  image = image.convert('RGB')
33
 
 
34
  sr_image = model.predict(image)
35
 
 
 
 
 
36
  buffer = BytesIO()
37
  sr_image.save(buffer, format="PNG")
38
  buffer.seek(0)
39
  return sr_image, buffer
40
+
41
  except Exception as e:
42
+ st.error(f"An error occurred during image enhancement: {e}")
43
  return None, None
44
 
45
  def main():
 
49
  uploaded_image = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
50
 
51
  if uploaded_image is not None:
52
+ try:
53
+ image = Image.open(uploaded_image)
54
+
55
+ # Anime toggle
56
+ anime = st.checkbox("Anime Image", value=False)
57
+
58
+ # Conditional scale options
59
+ if anime:
60
+ scale = "4x" # Set to 4x automatically when anime is selected
61
+ else:
62
+ scale = st.radio("Upscaling Factor", ["2x", "4x", "8x"], index=0)
63
+
64
+ scale_value = int(scale.replace('x', ''))
 
 
 
65
 
66
+ # Enhance button
67
+ if st.button("Restore Image"):
68
+ enhanced_image, buffer = enhance_image(image, scale_value, anime)
 
 
 
 
69
 
70
+ if enhanced_image:
71
+ # Show images side by side
72
+ col1, col2 = st.columns(2)
73
+ with col1:
74
+ st.image(image, caption="Original Image", use_column_width=True)
75
+ with col2:
76
+ st.image(enhanced_image, caption="Enhanced Image", use_column_width=True)
77
+
78
+ # Download button
79
+ st.download_button(
80
+ label="Download Enhanced Image",
81
+ data=buffer,
82
+ file_name="enhanced_image.png",
83
+ mime="image/png"
84
+ )
85
+ except Exception as e:
86
+ st.error(f"An error occurred while processing the image: {e}")
87
 
88
  if __name__ == "__main__":
89
  main()