Spaces:
Runtime error
Runtime error
import torch | |
import torch.optim | |
import model | |
import numpy as np | |
from PIL import Image | |
import streamlit as st | |
from torchvision import transforms | |
# Integer factor handed to model.enhance_net_nopool() in load_model() and used
# by fix_lowlight() to crop image height/width down to a multiple of it.
scale_factor: int = 1
def load_model(weights_path: str = "lowlight-dce-snapshot.pth") -> torch.nn.Module:
    """Build the enhancement network and load its trained weights.

    Args:
        weights_path: Path to the saved ``state_dict``. Defaults to
            ``lowlight-dce-snapshot.pth`` resolved against the current
            working directory. ``torch.load`` may unpickle arbitrary objects,
            so only pass trusted files.

    Returns:
        The network with the snapshot weights loaded, switched to
        evaluation mode for inference.
    """
    net = model.enhance_net_nopool(scale_factor)
    # map_location maps tensors onto the CPU, so snapshots saved on a GPU
    # still load on CPU-only hosts.
    state = torch.load(weights_path, map_location=torch.device("cpu"))
    net.load_state_dict(state)
    # eval() disables training-only behaviour (e.g. dropout, batch-norm
    # running-stat updates) if the network contains any; it returns the
    # module itself.
    return net.eval()
def fix_lowlight(image: Image.Image, net: "torch.nn.Module | None" = None) -> Image.Image:
    """Enhance a low-light image with the DCE network.

    Args:
        image: Input PIL image. It is converted to RGB first, so grayscale,
            palette and RGBA images are accepted.
        net: Optional pre-loaded network. When omitted, ``load_model()`` is
            called, which reads the weights from disk on every call; pass a
            model explicitly to reuse it across images.

    Returns:
        The enhanced image as an RGB PIL image. Its height and width are
        cropped down to multiples of ``scale_factor``.
    """
    if net is None:
        net = load_model()

    # Force three channels: the H x W x C slicing below requires an RGB array.
    pixels = torch.from_numpy(np.asarray(image.convert("RGB")) / 255.0).float()

    # Crop height/width to multiples of scale_factor, the factor the network
    # was constructed with in load_model().
    h = (pixels.shape[0] // scale_factor) * scale_factor
    w = (pixels.shape[1] // scale_factor) * scale_factor

    # HWC -> 1 x C x H x W: a single-image batch in channels-first layout.
    batch = pixels[:h, :w, :].permute(2, 0, 1).unsqueeze(0)

    # Pure inference: no_grad skips building an autograd graph, saving
    # memory and time.
    with torch.no_grad():
        enhanced, _ = net(batch)

    # ToPILImage scales float tensors by 255 and casts to uint8; clamping to
    # [0, 1] first makes out-of-range values saturate rather than wrap.
    return transforms.ToPILImage()(enhanced[0].clamp(0.0, 1.0)).convert("RGB")
def main():
    """Render the Streamlit page.

    Shows a title and description and a file uploader. Once a file is
    uploaded, displays the original image next to its enhanced version.
    """
    st.title("Lowlight Enhancement")
    for description in (
        "This is a simple lowlight enhancement app with great performance and does not require paired images to train.",
        "The model runs at 1000/11 FPS on single GPU/CPU on images with a size of 1200*900*3",
    ):
        st.write(description)

    upload = st.file_uploader("Lowlight Image")
    if not upload:
        # Nothing uploaded yet: only the header and uploader are shown.
        return

    original = Image.open(upload).convert('RGB')
    left, right = st.columns(2)

    left.write("Original (Lowlight)")
    left.image(original, caption="Lowlight Image", use_column_width=True)

    right.write("Enhanced")
    with st.spinner('🧠 Enhancing...'):
        enhanced = fix_lowlight(original)
    right.image(enhanced, caption="Enhanced Image", use_column_width=True)
# Run the app only when executed as a script. `streamlit run` executes this
# file as __main__, so the page still renders, while a plain import (e.g. from
# tests or another module) no longer has the side effect of drawing the UI.
if __name__ == "__main__":
    main()