danifei commited on
Commit
545f79d
·
verified ·
1 Parent(s): d960e2d

add other files

Browse files
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import os
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import torchvision.transforms as transforms
7
+ import torchvision
8
+ import numpy as np
9
+ import yaml
10
+ from huggingface_hub import hf_hub_download
11
+
12
+ from archs import Network_v3
13
+ from options.options import parse
14
+
15
+ path_opt = './options/test/LOLBlur.yml'
16
+
17
+ opt = parse(path_opt)
18
+
19
+ #define some auxiliary functions
20
+ pil_to_tensor = transforms.ToTensor()
21
+
22
+ # define some parameters based on the run we want to make
23
+ #selected network
24
+ network = opt['network']['name']
25
+
26
+ PATH_MODEL = opt['save']['path']
27
+
28
+ model = Network_v3(img_channel=opt['network']['img_channels'],
29
+ width=opt['network']['width'],
30
+ middle_blk_num=opt['network']['middle_blk_num'],
31
+ enc_blk_nums=opt['network']['enc_blk_nums'],
32
+ dec_blk_nums=opt['network']['dec_blk_nums'],
33
+ residual_layers=opt['network']['residual_layers'],
34
+ dilations=opt['network']['dilations'])
35
+
36
+ checkpoints = torch.load(opt['save']['best'])
37
+ # print(checkpoints)
38
+ model.load_state_dict(checkpoints['model_state_dict'])
39
+
40
+
41
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
42
+ model = model.to(device)
43
+
44
+ def load_img (filename):
45
+ img = Image.open(filename).convert("RGB")
46
+ img_tensor = pil_to_tensor(img)
47
+ return img_tensor
48
+
49
+ def process_img(image):
50
+ img = np.array(image)
51
+ img = img / 255.
52
+ img = img.astype(np.float32)
53
+ y = torch.tensor(img).permute(2,0,1).unsqueeze(0).to(device)
54
+
55
+ with torch.no_grad():
56
+ x_hat = model(y)
57
+
58
+ restored_img = x_hat.squeeze().permute(1,2,0).clamp_(0, 1).cpu().detach().numpy()
59
+ restored_img = np.clip(restored_img, 0. , 1.)
60
+
61
+ restored_img = (restored_img * 255.0).round().astype(np.uint8) # float32 to uint8
62
+ return Image.fromarray(restored_img) #(image, Image.fromarray(restored_img))
63
+
64
+ title = "Low-Light-Deblurring ✏️🖼️ 🤗"
65
+ description = ''' ## [Low Light Image deblurring enhancement](https://github.com/cidautai/Net-Low-light-Deblurring)
66
+
67
+ [Daniel Feijoo](https://github.com/danifei)
68
+
69
+ Fundación Cidaut
70
+
71
+
72
+ > **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
73
+ **This demo expects an image with some degradations.**
74
+ Due to the GPU memory limitations, the app might crash if you feed a high-resolution image (2K, 4K). <br>
75
+ The model was trained using mostly synthetic data, thus it might not work great on real-world complex images.
76
+
77
+ <br>
78
+ '''
79
+
80
+ examples = [['examples/inputs/0010.png'],
81
+ ['examples/inputs/0060.png'],
82
+ ['examples/inputs/0075.png'],
83
+ ["examples/inputs/0087.png"],
84
+ ["examples/inputs/0088.png"]]
85
+
86
+ css = """
87
+ .image-frame img, .image-container img {
88
+ width: auto;
89
+ height: auto;
90
+ max-width: none;
91
+ }
92
+ """
93
+
94
+ demo = gr.Interface(
95
+ fn = process_img,
96
+ inputs = [
97
+ gr.Image(type = 'pil', label = 'input')
98
+ ],
99
+ outputs = [gr.Image(type='pil', label = 'output')],
100
+ title = title,
101
+ description = description,
102
+ examples = examples,
103
+ css = css
104
+ )
105
+
106
+ if __name__ == '__main__':
107
+ demo.launch()
examples/inputs/0010.png ADDED
examples/inputs/0060.png ADDED
examples/inputs/0075.png ADDED
examples/inputs/0087.png ADDED
examples/inputs/0088.png ADDED
examples/results/0010.png ADDED
examples/results/0060.png ADDED
examples/results/0075.png ADDED
examples/results/0087.png ADDED
examples/results/0088.png ADDED
inputs/real/0000_0082.png ADDED
inputs/real/0002_0006.png ADDED
inputs/real/0031_0043.png ADDED
inputs/real/0032_0060.png ADDED
inputs/real/0033_0025.png ADDED
inputs/real/0034_0049.png ADDED
inputs/real/075_blur_9.png ADDED
inputs/real/188_blur_20.png ADDED
inputs/real/230_blur_21.png ADDED
inputs/real/234_blur_15.png ADDED
inputs/synthetic/0075.png ADDED
inputs/synthetic/0085.png ADDED
inputs/synthetic/0087.png ADDED
inputs/synthetic/0088.png ADDED
options/options.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ import logging
4
+ import yaml
5
+ from collections import OrderedDict
6
+ try:
7
+ from yaml import CLoader as Loader, CDumper as Dumper
8
+ except ImportError:
9
+ from yaml import Loader, Dumper
10
+
11
+
12
+ def OrderedYaml():
13
+ '''yaml orderedDict support'''
14
+ _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
15
+
16
+ def dict_representer(dumper, data):
17
+ return dumper.represent_dict(data.items())
18
+
19
+ def dict_constructor(loader, node):
20
+ return OrderedDict(loader.construct_pairs(node))
21
+
22
+ Dumper.add_representer(OrderedDict, dict_representer)
23
+ Loader.add_constructor(_mapping_tag, dict_constructor)
24
+ return Loader, Dumper
25
+
26
+ #-----------------------
27
+ Loader, Dumper = OrderedYaml()
28
+
29
+ def parse(opt_path):
30
+ with open(opt_path, mode='r') as f:
31
+ opt = yaml.load(f, Loader=Loader)
32
+ return opt
33
+
34
+
35
+
36
+ if __name__ == '__main__':
37
+
38
+ path_yaml = './train/NBDN.yml'
39
+ with open(path_yaml, mode='r') as f:
40
+ opt = yaml.load(f, Loader=Loader)
41
+ opt = parse(path_yaml)
42
+ # print(opt)
43
+ print(type(opt['network']['width']))
44
+ # print(opt['gpu'])
options/predict/LOLBlur.yml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### network structures
2
+ network:
3
+ name: Network
4
+ img_channels: 3
5
+ width: 32
6
+ middle_blk_num: 3
7
+ enc_blk_nums: [1, 2, 3]
8
+ dec_blk_nums: [3, 1, 1]
9
+ enc_blk_nums_map: None
10
+ middle_blk_num_map: None
11
+ residual_layers: None
12
+ dilations: [1, 4, 9]
13
+ spatial: False
14
+ extra_depth_wise: True
15
+
16
+ #### save model
17
+ save:
18
+ best: ./models/bests/Network_no_attention_light_LOLBlur.pt
19
+
20
+ LOLBlur:
21
+ inputs_path: ./inputs/synthetic
22
+ results_path: ./results/synthetic
23
+ RealBlur:
24
+ inputs_path: ./inputs/real
25
+ results_path: ./results/real
options/train/All_LOL.yml ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### general settings
2
+ resume_training:
3
+ resume_training: True
4
+ resume: must
5
+ id: 2oceqxkt
6
+
7
+ #### training devices
8
+ device:
9
+ cuda: True
10
+ gpus: 0
11
+
12
+ #### datasets
13
+ datasets:
14
+ name: All_LOL
15
+ train:
16
+ train_path: /home/leadergpu/Datasets
17
+
18
+ n_workers: 4 # per GPU
19
+ batch_size_train: 12
20
+ cropsize: 256 # size you want to crop out as input sample.
21
+ flips: True
22
+ verbose: True
23
+ crop_type: Random
24
+ val:
25
+ test_path: /home/leadergpu/Datasets
26
+ batch_size_test: 1
27
+
28
+ #### network structures
29
+ network:
30
+ name: Network
31
+ img_channels: 3
32
+ width: 32
33
+ middle_blk_num: 3
34
+ enc_blk_nums: [1, 2, 3]
35
+ dec_blk_nums: [3, 1, 1]
36
+ enc_blk_nums_map: None
37
+ middle_blk_num_map: None
38
+ residual_layers: 2
39
+ dilations: [1, 4, 9]
40
+ spatial: None
41
+ extra_depth_wise: True
42
+
43
+
44
+ #### training settings: learning rate scheme, loss
45
+ train:
46
+ lr_initial: !!float 5e-4
47
+ lr_scheme: CosineAnnealing
48
+ betas: [0.9, 0.9]
49
+ epochs: 500
50
+ lr_gamma: 0.5
51
+ weight_decay: !!float 1e-3
52
+ eta_min: !!float 1e-6
53
+
54
+ pixel_criterion: l1
55
+ pixel_weight: 1.0
56
+
57
+ perceptual: True
58
+ perceptual_criterion: l1
59
+ perceptual_weight: 0.01
60
+ perceptual_reduction: mean
61
+
62
+ edge: True
63
+ edge_criterion: l2
64
+ edge_weight: 50.0
65
+ edge_reduction: mean
66
+
67
+ frequency: True
68
+ frequency_criterion: l2
69
+ frequency_weight: 0.01
70
+ frequency_reduction: mean
71
+
72
+ #### save model
73
+ save:
74
+ path: ./models/Network_all_LOL.pt
75
+ best: ./models/bests/
76
+ #### wandb:
77
+ wandb:
78
+ init: True
79
+ project: LOLBlur
80
+ entity: cidautai
81
+ name: Network_all_LOL
82
+ save_code: True
options/train/LOL.yml ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### general settings
2
+ resume_training:
3
+ resume_training: True
4
+ resume: must
5
+ id: waye95l7
6
+
7
+
8
+ #### datasets
9
+ datasets:
10
+ name: LOL
11
+ train:
12
+ train_path: /mnt/valab-datasets/LOL/train
13
+
14
+ n_workers: 4 # per GPU
15
+ batch_size_train: 4
16
+ cropsize: 256 # size you want to crop out as input sample.
17
+ flips: True
18
+ verbose: True
19
+ crop_type: Random
20
+ val:
21
+ test_path: /mnt/valab-datasets/LOL/test
22
+ cropsize: 256
23
+ batch_size_test: 1
24
+
25
+ #### network structures
26
+ network:
27
+ name: Network_v3
28
+ img_channels: 3
29
+ width: 32
30
+ middle_blk_num: 3
31
+ enc_blk_nums: [1, 2, 3]
32
+ dec_blk_nums: [3, 1, 1]
33
+ enc_blk_nums_map: None
34
+ middle_blk_num_map: None
35
+ residual_layers: 1
36
+ dilations: [1, 4]
37
+ spatial: None
38
+ extra_depth_wise: False
39
+
40
+
41
+ #### training settings: learning rate scheme, loss
42
+ train:
43
+ lr_initial: !!float 5e-4
44
+ lr_scheme: CosineAnnealing
45
+ betas: [0.9, 0.9]
46
+ epochs: 500
47
+ lr_gamma: 0.5
48
+ weight_decay: !!float 1e-3
49
+ eta_min: !!float 1e-6
50
+
51
+ pixel_criterion: l1
52
+ pixel_weight: 1.0
53
+
54
+ perceptual: True
55
+ perceptual_criterion: l1
56
+ perceptual_weight: 0.01
57
+ perceptual_reduction: mean
58
+
59
+ edge: True
60
+ edge_criterion: l2
61
+ edge_weight: 50.0
62
+ edge_reduction: mean
63
+
64
+ frequency: True
65
+ frequency_criterion: l2
66
+ frequency_weight: 0.01
67
+ frequency_reduction: mean
68
+
69
+ #### save model
70
+ save:
71
+ path: ./models/Network_v3_interpolate.pt
72
+ best: ./models/bests/
73
+ #### wandb:
74
+ wandb:
75
+ init: True
76
+ project: LOLBlur
77
+ entity: cidautai
78
+ name: Network_v3_interpolate_extraDW_LOLBlur
79
+ save_code: True
options/train/LOLBlur.yml ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### general settings
2
+ resume_training:
3
+ resume_training: True
4
+ resume: must
5
+ id: qa5ramgk
6
+
7
+ #### training devices
8
+ device:
9
+ cuda: True
10
+ gpus: 0
11
+
12
+ #### datasets
13
+ datasets:
14
+ name: LOLBlur
15
+ train:
16
+ train_path: /home/leadergpu/Datasets/LOLBlur_temp/train
17
+
18
+ n_workers: 4 # per GPU
19
+ batch_size_train: 12
20
+ cropsize: 256 # size you want to crop out as input sample.
21
+ flips: True
22
+ verbose: True
23
+ crop_type: Random
24
+ val:
25
+ test_path: /home/leadergpu/Datasets/LOLBlur_temp/test
26
+ batch_size_test: 1
27
+
28
+ #### network structures
29
+ network:
30
+ name: Network
31
+ img_channels: 3
32
+ width: 32
33
+ middle_blk_num: 3
34
+ enc_blk_nums: [1, 2, 3]
35
+ dec_blk_nums: [3, 1, 1]
36
+ enc_blk_nums_map: None
37
+ middle_blk_num_map: None
38
+ residual_layers: None
39
+ dilations: [1, 4]
40
+ spatial: None
41
+ extra_depth_wise: True
42
+
43
+
44
+ #### training settings: learning rate scheme, loss
45
+ train:
46
+ lr_initial: !!float 2e-5
47
+ lr_scheme: CosineAnnealing
48
+ betas: [0.9, 0.9]
49
+ epochs: 700
50
+ lr_gamma: 0.5
51
+ weight_decay: !!float 1e-3
52
+ eta_min: !!float 1e-6
53
+
54
+ pixel_criterion: l1
55
+ pixel_weight: 1.0
56
+
57
+ perceptual: True
58
+ perceptual_criterion: l1
59
+ perceptual_weight: 0.01
60
+ perceptual_reduction: mean
61
+
62
+ edge: True
63
+ edge_criterion: l2
64
+ edge_weight: 50.0
65
+ edge_reduction: mean
66
+
67
+ frequency: True
68
+ frequency_criterion: l2
69
+ frequency_weight: 0.01
70
+ frequency_reduction: sum
71
+
72
+ #### save model
73
+ save:
74
+ path: ./models/Network_v3_new_network_noFPA_LOLBlur.pt
75
+ best: ./models/bests/
76
+ new: ./models/Network_v3_new_network_noFPA_LOLBlur_v2.pt
77
+ #### wandb:
78
+ wandb:
79
+ init: True
80
+ project: LOLBlur
81
+ entity: cidautai
82
+ name: Network_v3_new_network_noFPA_LOLBlur
83
+ save_code: True
84
+
85
+
requirements.txt ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ altair==5.3.0
3
+ annotated-types==0.7.0
4
+ anyio==4.4.0
5
+ attrs==23.2.0
6
+ certifi==2024.6.2
7
+ charset-normalizer==3.3.2
8
+ click==8.1.7
9
+ contourpy==1.2.1
10
+ cycler==0.12.1
11
+ dnspython==2.6.1
12
+ docker-pycreds==0.4.0
13
+ email_validator==2.2.0
14
+ exceptiongroup==1.2.1
15
+ fastapi==0.111.0
16
+ fastapi-cli==0.0.4
17
+ ffmpy==0.3.2
18
+ filelock==3.15.3
19
+ fonttools==4.53.0
20
+ fsspec==2024.6.0
21
+ gitdb==4.0.11
22
+ GitPython==3.1.43
23
+ gradio==4.36.1
24
+ gradio_client==1.0.1
25
+ h11==0.14.0
26
+ httpcore==1.0.5
27
+ httptools==0.6.1
28
+ httpx==0.27.0
29
+ huggingface-hub==0.23.4
30
+ idna==3.7
31
+ importlib_resources==6.4.0
32
+ Jinja2==3.1.4
33
+ jsonschema==4.22.0
34
+ jsonschema-specifications==2023.12.1
35
+ kiwisolver==1.4.5
36
+ kornia==0.7.2
37
+ kornia_rs==0.1.3
38
+ lpips==0.1.4
39
+ markdown-it-py==3.0.0
40
+ MarkupSafe==2.1.5
41
+ matplotlib==3.9.0
42
+ mdurl==0.1.2
43
+ mpmath==1.3.0
44
+ networkx==3.3
45
+ numpy==2.0.0
46
+ nvidia-cublas-cu12==12.1.3.1
47
+ nvidia-cuda-cupti-cu12==12.1.105
48
+ nvidia-cuda-nvrtc-cu12==12.1.105
49
+ nvidia-cuda-runtime-cu12==12.1.105
50
+ nvidia-cudnn-cu12==8.9.2.26
51
+ nvidia-cufft-cu12==11.0.2.54
52
+ nvidia-curand-cu12==10.3.2.106
53
+ nvidia-cusolver-cu12==11.4.5.107
54
+ nvidia-cusparse-cu12==12.1.0.106
55
+ nvidia-nccl-cu12==2.20.5
56
+ nvidia-nvjitlink-cu12==12.5.40
57
+ nvidia-nvtx-cu12==12.1.105
58
+ opencv-python==4.10.0.84
59
+ orjson==3.10.5
60
+ packaging==24.1
61
+ pandas==2.2.2
62
+ pillow==10.3.0
63
+ platformdirs==4.2.2
64
+ protobuf==5.27.1
65
+ psutil==6.0.0
66
+ ptflops==0.7.3
67
+ pydantic==2.7.4
68
+ pydantic_core==2.18.4
69
+ pydub==0.25.1
70
+ Pygments==2.18.0
71
+ pyparsing==3.1.2
72
+ python-dateutil==2.9.0.post0
73
+ python-dotenv==1.0.1
74
+ python-multipart==0.0.9
75
+ pytorch-msssim==1.0.0
76
+ pytz==2024.1
77
+ PyYAML==6.0.1
78
+ referencing==0.35.1
79
+ requests==2.32.3
80
+ rich==13.7.1
81
+ rpds-py==0.18.1
82
+ ruff==0.4.10
83
+ scipy==1.13.1
84
+ semantic-version==2.10.0
85
+ sentry-sdk==2.6.0
86
+ setproctitle==1.3.3
87
+ shellingham==1.5.4
88
+ six==1.16.0
89
+ smmap==5.0.1
90
+ sniffio==1.3.1
91
+ starlette==0.37.2
92
+ sympy==1.12.1
93
+ tomlkit==0.12.0
94
+ toolz==0.12.1
95
+ torch==2.3.1
96
+ torchvision==0.18.1
97
+ tqdm==4.66.4
98
+ triton==2.3.1
99
+ typer==0.12.3
100
+ typing_extensions==4.12.2
101
+ tzdata==2024.1
102
+ ujson==5.10.0
103
+ urllib3==2.2.2
104
+ uvicorn==0.30.1
105
+ uvloop==0.19.0
106
+ wandb==0.17.2
107
+ watchfiles==0.22.0
108
+ websockets==11.0.3