File size: 3,298 Bytes
12456dc
f4f4778
 
aa1deb6
f4f4778
12456dc
f4f4778
 
 
 
 
 
 
 
 
12456dc
 
f4f4778
 
 
 
12456dc
 
 
 
 
 
 
 
 
 
 
 
f4f4778
 
 
 
 
 
 
 
 
 
 
 
12456dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4f4778
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
from PIL import Image
import streamlit as st
from tensorflow.keras import utils, backend

from enhance_me import MIRNet, ZeroDCE


def get_mirnet_object() -> MIRNet:
    """Fetch the MIRNet weights and return a model with them loaded.

    The weights file ``weights_lol_128.h5`` is requested through
    ``utils.get_file`` with ``cache_dir="."`` and ``cache_subdir="weights"``.

    Returns:
        A ``MIRNet`` instance on which ``build_model()`` and
        ``load_weights()`` have been called.
    """
    weights_path = utils.get_file(
        "weights_lol_128.h5",
        "https://github.com/soumik12345/enhance-me/releases/download/v0.2/weights_lol_128.h5",
        cache_dir=".",
        cache_subdir="weights",
    )
    mirnet = MIRNet()
    mirnet.build_model()
    # Load from the path get_file reports instead of a hard-coded copy of it,
    # so the download location and the load location can never disagree
    # (e.g. should Keras pick a different cache directory).
    mirnet.load_weights(weights_path)
    return mirnet


def get_zero_dce_object(model_alias: str) -> ZeroDCE:
    """Fetch the Zero-DCE weights named by *model_alias* and load them.

    The file ``<model_alias>.h5`` is requested through ``utils.get_file``
    from the project's v0.4 release, with ``cache_dir="."`` and
    ``cache_subdir="weights"``.

    Args:
        model_alias: Base name (without the ``.h5`` extension) of the weights
            file to download, e.g. ``"dce_weights_lol_128"``.

    Returns:
        A ``ZeroDCE`` instance on which ``load_weights()`` has been called.
    """
    weights_file = f"{model_alias}.h5"
    weights_path = utils.get_file(
        weights_file,
        f"https://github.com/soumik12345/enhance-me/releases/download/v0.4/{weights_file}",
        cache_dir=".",
        cache_subdir="weights",
    )
    dce = ZeroDCE()
    # Load from the path get_file reports instead of rebuilding it by hand,
    # so the download location and the load location can never disagree
    # (e.g. should Keras pick a different cache directory).
    dce.load_weights(weights_path)
    return dce


def main():
    """Render the "Enhance Me" Streamlit page.

    The sidebar walks the user through three choices: the application, an
    image upload, and the model. Each step uses an empty entry (or ``None``
    for the upload) as its "nothing chosen yet" state, in which case the
    function returns early without rendering the later steps. Once a model is
    chosen, it is loaded, ``infer`` is run on the uploaded image, and the
    result is displayed below the original.
    """
    st.markdown("# Enhance Me")
    st.markdown("Made with :heart: by [geekyRakshit](http://github.com/soumik12345)")
    application = st.sidebar.selectbox(
        "Please select the application:", ("", "Low-light enhancement")
    )
    if application != "Low-light enhancement":
        return

    uploaded_file = st.sidebar.file_uploader("Select your image:")
    if uploaded_file is None:
        return

    original_image = Image.open(uploaded_file)
    st.image(original_image, caption="original image")
    model_option = st.sidebar.selectbox(
        "Please select the model:",
        (
            "",
            "MIRNet",
            "Zero-DCE (dce_weights_lol_128)",
            "Zero-DCE (dce_weights_lol_128_resize)",
            "Zero-DCE (dce_weights_lol_256)",
            "Zero-DCE (dce_weights_lol_256_resize)",
            "Zero-DCE (dce_weights_unpaired_128)",
            "Zero-DCE (dce_weights_unpaired_128_resize)",
            "Zero-DCE (dce_weights_unpaired_256)",
            "Zero-DCE (dce_weights_unpaired_256_resize)",
        ),
    )
    if model_option == "":
        return

    try:
        if model_option == "MIRNet":
            st.sidebar.info("Loading MIRNet...")
            mirnet = get_mirnet_object()
            st.sidebar.info("Done!")
            st.sidebar.info("Processing Image...")
            enhanced_image = mirnet.infer(original_image)
            st.sidebar.info("Done!")
            st.image(enhanced_image, caption="enhanced image")
        elif "Zero-DCE" in model_option:
            # The weights alias is the text between the parentheses of the
            # selected option, e.g. "Zero-DCE (dce_weights_lol_128)".
            alias_start = model_option.find("(") + 1
            alias_end = model_option.find(")")
            model_alias = model_option[alias_start:alias_end]
            st.sidebar.info("Loading Zero-DCE...")
            zero_dce = get_zero_dce_object(model_alias)
            st.sidebar.info("Done!")
            # Same status message as the MIRNet branch, so the second
            # "Done!" has a matching "in progress" notice.
            st.sidebar.info("Processing Image...")
            enhanced_image = zero_dce.infer(original_image)
            st.sidebar.info("Done!")
            st.image(enhanced_image, caption="enhanced image")
    finally:
        # Clear the Keras session even when loading or inference raises.
        backend.clear_session()


# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()