Spaces
haoning.wu committed
Commit f8ea2c9 • 1 Parent(s): bf92928

Add InstructIR plugin!
- app.py +101 -5
- eval5d.yml +40 -0
- examples/211.jpg +0 -0
- examples/extreme_ironing.jpg +0 -0
- examples/frog.png +3 -0
- examples/gopro.png +3 -0
- examples/lol_748.png +3 -0
- examples/noise50.png +3 -0
- examples/sausage.jpg +0 -0
- insir_models/.ipynb_checkpoints/instructir-checkpoint.py +134 -0
- insir_models/.ipynb_checkpoints/nafnet-checkpoint.py +201 -0
- insir_models/.ipynb_checkpoints/nafnet_utils-checkpoint.py +146 -0
- insir_models/__pycache__/instructir.cpython-39.pyc +0 -0
- insir_models/__pycache__/nafnet.cpython-39.pyc +0 -0
- insir_models/__pycache__/nafnet_utils.cpython-39.pyc +0 -0
- insir_models/instructir.py +134 -0
- insir_models/nafnet.py +201 -0
- insir_models/nafnet_utils.py +146 -0
- insir_text/.ipynb_checkpoints/models-checkpoint.py +65 -0
- insir_text/__pycache__/models.cpython-39.pyc +0 -0
- insir_text/models.py +65 -0
- insir_text/sample_prompts.json +55 -0
app.py
CHANGED
@@ -1,10 +1,85 @@
+import os, yaml
 import gradio as gr
 import requests
+import argparse
+
 from PIL import Image
 
+import numpy as np
 import torch
 from transformers import AutoModelForCausalLM
 
+from huggingface_hub import hf_hub_download
+
+
+## InstructIR Plugin ##
+from insir_models import instructir
+from insir_text.models import LanguageModel, LMHead
+
+hf_hub_download(repo_id="marcosv/InstructIR", filename="im_instructir-7d.pt", local_dir="./")
+hf_hub_download(repo_id="marcosv/InstructIR", filename="lm_instructir-7d.pt", local_dir="./")
+
+CONFIG = "eval5d.yml"
+LM_MODEL = "lm_instructir-7d.pt"
+MODEL_NAME = "im_instructir-7d.pt"
+
+def dict2namespace(config):
+    namespace = argparse.Namespace()
+    for key, value in config.items():
+        if isinstance(value, dict):
+            new_value = dict2namespace(value)
+        else:
+            new_value = value
+        setattr(namespace, key, new_value)
+    return namespace
+
+
+# parse config file
+with open(os.path.join(CONFIG), "r") as f:
+    config = yaml.safe_load(f)
+
+cfg = dict2namespace(config)
+
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ir_model = instructir.create_model(input_channels =cfg.model.in_ch, width=cfg.model.width, enc_blks = cfg.model.enc_blks,
+                                   middle_blk_num = cfg.model.middle_blk_num, dec_blks = cfg.model.dec_blks, txtdim=cfg.model.textdim)
+ir_model = ir_model.to(device)
+print ("IMAGE MODEL CKPT:", MODEL_NAME)
+ir_model.load_state_dict(torch.load(MODEL_NAME, map_location="cpu"), strict=True)
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+LMODEL = cfg.llm.model
+language_model = LanguageModel(model=LMODEL)
+lm_head = LMHead(embedding_dim=cfg.llm.model_dim, hidden_dim=cfg.llm.embd_dim, num_classes=cfg.llm.nclasses)
+lm_head = lm_head.to(device)
+
+print("LMHEAD MODEL CKPT:", LM_MODEL)
+lm_head.load_state_dict(torch.load(LM_MODEL, map_location="cpu"), strict=True)
+
+def process_img(image, prompt=None):
+    if prompt is None:
+        prompt = chat("How to improve the quality of the image?", [], image, None, None, None)
+        prompt += "Please help me improve its quality!"
+    print(prompt)
+    img = np.array(image)
+    img = img / 255.
+    img = img.astype(np.float32)
+    y = torch.tensor(img).permute(2,0,1).unsqueeze(0).to(device)
+
+    lm_embd = language_model(prompt)
+    lm_embd = lm_embd.to(device)
+
+    with torch.no_grad():
+        text_embd, deg_pred = lm_head(lm_embd)
+        x_hat = ir_model(y, text_embd)
+
+    restored_img = x_hat.squeeze().permute(1,2,0).clamp_(0, 1).cpu().detach().numpy()
+    restored_img = np.clip(restored_img, 0. , 1.)
+
+    restored_img = (restored_img * 255.0).round().astype(np.uint8) # float32 to uint8
+    return Image.fromarray(restored_img) #(image, Image.fromarray(restored_img))
+
+## InstructIR Plugin ##
 model = AutoModelForCausalLM.from_pretrained("q-future/co-instruct-preview",
                                              trust_remote_code=True,
                                              torch_dtype=torch.float16,
@@ -15,7 +90,7 @@ def chat(message, history, image_1, image_2, image_3, image_4):
     print(history)
     if history:
         if image_1 is not None and image_2 is None:
-            past_message = "USER: The image: <|image|>" + history[0][0] + " ASSISTANT:" + history[0][1]
+            past_message = "USER: The input image: <|image|>" + history[0][0] + " ASSISTANT:" + history[0][1]
             for i in range((len(history) - 1)):
                 past_message += "USER:" +history[i][0] + " ASSISTANT:" + history[i][1] + "</s>"
             message = past_message + "USER:" + message + " ASSISTANT:"
@@ -42,7 +117,7 @@ def chat(message, history, image_1, image_2, image_3, image_4):
             images = [image_1, image_2, image_3, image_4]
     else:
         if image_1 is not None and image_2 is None:
-            message = "USER: The image: <|image|>" + message + " ASSISTANT:"
+            message = "USER: The input image: <|image|>" + message + " ASSISTANT:"
             images = [image_1]
         if image_1 is not None and image_2 is not None:
             if image_3 is None:
@@ -58,14 +133,24 @@ def chat(message, history, image_1, image_2, image_3, image_4):
 
     print(message)
 
-    return model.tokenizer.batch_decode(model.chat(message, images, max_new_tokens=
+    return model.tokenizer.batch_decode(model.chat(message, images, max_new_tokens=600).clamp(0, 100000))[0].split("ASSISTANT:")[-1]
+
+#### Image,Prompts examples
+examples = [
+    ["Which part of the image is relatively clearer, the upper part or the lower part? Please analyze in details.", Image.open("examples/sausage.jpg"), None],
+    ["Which image is noisy, and which one is with motion blur? Please analyze in details.", Image.open("examples/211.jpg"), Image.open("examples/frog.png")],
+    ["What is the problem in this image, and how to fix it? Please answer my questions one by one.", Image.open("examples/lol_748.png"), None],
+]
+
 
 
+title = "Q-Instruct🧑🏫"
 with gr.Blocks(title="img") as demo:
     title_markdown = ("""
-
+
 <h1 align="center"><a href="https://github.com/Q-Future/Q-Instruct"><img src="https://github.com/Q-Future/Q-Instruct/blob/main/q_instruct_logo.png?raw=true", alt="Q-Instruct (mPLUG-Owl-2)" border="0" style="margin: 0 auto; height: 85px;" /></a> </h1>
 <h2 align="center">Q-Instruct: Improving Low-level Visual Abilities for Multi-modality Foundation Models</h2>
+<div align="center">Super Version of Q-Instruct with Multi-image (up to 4, same as GPT-4V) Support! We also support <a href='https://huggingface.co/marcosv/InstructIR'>InstructIR</a> as PLUGIN!</div>
 <h5 align="center"> Please find our more accurate visual scoring demo on <a href='https://huggingface.co/spaces/teowu/OneScorer'>[OneScorer]</a>!</h2>
 <div align="center">
 <div style="display:flex; gap: 0.25rem;" align="center">
@@ -81,5 +166,16 @@ with gr.Blocks(title="img") as demo:
         input_img_2 = gr.Image(type='pil', label="Image 2 (Second image)")
         input_img_3 = gr.Image(type='pil', label="Image 3 (Third image)")
         input_img_4 = gr.Image(type='pil', label="Image 4 (Third image)")
-    gr.
+    with gr.Row():
+        with gr.Column(scale=2):
+            gr.ChatInterface(fn = chat, additional_inputs=[input_img_1, input_img_2, input_img_3, input_img_4], examples=examples)
+        with gr.Column(scale=1):
+            input_image_ir = gr.Image(type="pil", label="Image for Auto Restoration")
+            output_image_ir = gr.Image(type="pil", label="Output of Auto Restoration")
+            gr.Interface(
+                fn=process_img,
+                inputs=[input_image_ir],
+                outputs=[output_image_ir],
+                examples=[Image.open("examples/gopro.png"), Image.open("examples/noise50.png"), Image.open("examples/lol_748.png")],
+            )
 demo.launch(share=True)
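To make the data flow of process_img() above easier to follow, here is a minimal standalone sketch of the same call chain. It is an illustration only, not part of the committed files: it assumes the module-level objects built in app.py (ir_model, language_model, lm_head, device) are in scope, and the instruction string and output filename are made up for the example.

# Illustration only (not in the commit): the InstructIR call chain used by process_img() above.
# Assumes ir_model, language_model, lm_head and device exist exactly as defined in app.py.
import numpy as np
import torch
from PIL import Image

image = Image.open("examples/noise50.png").convert("RGB")   # example image added in this commit
y = torch.tensor(np.array(image).astype(np.float32) / 255.).permute(2, 0, 1).unsqueeze(0).to(device)

with torch.no_grad():
    lm_embd = language_model("Please remove the noise from this photo.").to(device)
    text_embd, deg_pred = lm_head(lm_embd)   # instruction embedding + degradation logits
    x_hat = ir_model(y, text_embd)           # instruction-conditioned restoration

out = (x_hat.squeeze().permute(1, 2, 0).clamp(0, 1).cpu().numpy() * 255.0).round().astype(np.uint8)
Image.fromarray(out).save("restored.png")    # hypothetical output path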
eval5d.yml
ADDED
@@ -0,0 +1,40 @@
llm:
  model: 'TaylorAI/bge-micro-v2' # See Paper Sec. 3.2 and Appendix
  model_dim: 384
  embd_dim: 256
  nclasses: 7 # noise, blur, rain, haze, lol, enhancement, upsampling (Paper Sec. 4.3)
  weights: False

model:
  arch: "instructir"
  use_text: True
  in_ch: 3
  out_ch: 3
  width : 32
  enc_blks: [2, 2, 4, 8]
  middle_blk_num: 4
  dec_blks: [2, 2, 2, 2]
  textdim: 256
  weights: False

test:
  batch_size: 1
  num_workers: 3

  dn_datapath: "data/denoising_testsets/"
  dn_datasets: ["CBSD68", "urban100", "Kodak24", "McMaster"]
  dn_sigmas: [15, 25, 50]

  rain_targets: ["data/Rain/rain_test/Rain100L/target/"]
  rain_inputs: ["data/Rain/rain_test/Rain100L/input/"]

  haze_targets: "data/SOTS-OUT/GT/"
  haze_inputs : "data/SOTS-OUT/IN/"

  lol_targets: "data/LOL/eval15/high/"
  lol_inputs : "data/LOL/eval15/low/"

  gopro_targets: "data/gopro_test/GoPro/target/"
  gopro_inputs: "data/gopro_test/GoPro/input/"
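A quick illustration (not part of the commit) of how app.py consumes this config: the nested YAML is converted into attribute access with the same dict2namespace helper defined in app.py above.

# Illustration only: loading eval5d.yml the way app.py does.
import argparse
import yaml

def dict2namespace(config):                 # same helper as in app.py above
    namespace = argparse.Namespace()
    for key, value in config.items():
        setattr(namespace, key, dict2namespace(value) if isinstance(value, dict) else value)
    return namespace

with open("eval5d.yml", "r") as f:
    cfg = dict2namespace(yaml.safe_load(f))

print(cfg.llm.model)       # 'TaylorAI/bge-micro-v2'
print(cfg.model.width)     # 32
print(cfg.model.enc_blks)  # [2, 2, 4, 8]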
examples/211.jpg
CHANGED (Git LFS)
examples/extreme_ironing.jpg
CHANGED (Git LFS)
examples/frog.png
ADDED (Git LFS)
examples/gopro.png
ADDED (Git LFS)
examples/lol_748.png
ADDED (Git LFS)
examples/noise50.png
ADDED (Git LFS)
examples/sausage.jpg
CHANGED (Git LFS)
insir_models/.ipynb_checkpoints/instructir-checkpoint.py
ADDED
@@ -0,0 +1,134 @@
(Jupyter checkpoint copy; its 134 added lines are identical to insir_models/instructir.py below.)
insir_models/.ipynb_checkpoints/nafnet-checkpoint.py
ADDED
@@ -0,0 +1,201 @@
(Jupyter checkpoint copy; its 201 added lines are identical to insir_models/nafnet.py below.)
insir_models/.ipynb_checkpoints/nafnet_utils-checkpoint.py
ADDED
@@ -0,0 +1,146 @@
(Jupyter checkpoint copy; its 146 added lines are identical to insir_models/nafnet_utils.py below.)
insir_models/__pycache__/instructir.cpython-39.pyc
ADDED
Binary file (4.22 kB)
insir_models/__pycache__/nafnet.cpython-39.pyc
ADDED
Binary file (5.53 kB)
insir_models/__pycache__/nafnet_utils.cpython-39.pyc
ADDED
Binary file (5.4 kB)
insir_models/instructir.py
ADDED
@@ -0,0 +1,134 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm

from insir_models.nafnet_utils import Local_Base, LayerNorm2d
from insir_models.nafnet import SimpleGate, NAFBlock


class ICB(nn.Module):
    """
    Instruction Condition Block (ICB)
    Paper Section 3.3
    """

    def __init__(self, feature_dim, text_dim=768):
        super(ICB, self).__init__()
        self.fc = nn.Linear(text_dim, feature_dim)
        self.block = NAFBlock(feature_dim)
        self.beta = nn.Parameter(torch.zeros((1, feature_dim, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, feature_dim, 1, 1)), requires_grad=True)

    def forward(self, x, text_embedding):
        gating_factors = torch.sigmoid(self.fc(text_embedding))
        gating_factors = gating_factors.unsqueeze(-1).unsqueeze(-1)

        f = x * self.gamma + self.beta  # 1) learned feature scaling/modulation
        f = f * gating_factors          # 2) (soft) feature routing based on text
        f = self.block(f)               # 3) block feature enhancement
        return f + x


class InstructIR(nn.Module):
    """
    InstructIR model using NAFNet (ECCV 2022) as backbone.
    The model takes as input an RGB image and a text embedding (encoded instruction).
    Described in Paper Section 3.3
    """

    def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[], txtdim=768):
        super().__init__()

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                               bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
                                bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.enc_cond = nn.ModuleList()
        self.dec_cond = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )

            self.enc_cond.append(ICB(chan, txtdim))

            self.downs.append(
                nn.Conv2d(chan, 2*chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = nn.Sequential(
            *[NAFBlock(chan) for _ in range(middle_blk_num)]
        )

        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )
            # Add text embedding as modulation
            self.dec_cond.append(ICB(chan, txtdim))

        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp, txtembd):
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)

        x = self.intro(inp)
        encs = []

        for encoder, enc_mod, down in zip(self.encoders, self.enc_cond, self.downs):
            x = encoder(x)
            x = enc_mod(x, txtembd)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        for decoder, up, enc_skip, dec_mod in zip(self.decoders, self.ups, encs[::-1], self.dec_cond):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)
            x = dec_mod(x, txtembd)

        x = self.ending(x)
        x = x + inp

        return x[:, :, :H, :W]

    def check_image_size(self, x):
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x


def create_model(input_channels = 3, width = 32, enc_blks = [2, 2, 4, 8], middle_blk_num = 12, dec_blks = [2, 2, 2, 2], txtdim=768):

    net = InstructIR(img_channel=input_channels, width=width, middle_blk_num=middle_blk_num,
                     enc_blk_nums=enc_blks, dec_blk_nums=dec_blks, txtdim=txtdim)

    return net
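A small shape-check sketch for the model above (illustration only, with randomly initialized weights; app.py loads im_instructir-7d.pt instead). The hyper-parameters mirror eval5d.yml.

# Illustration only: build InstructIR with the eval5d.yml hyper-parameters and run a dummy forward pass.
import torch
from insir_models import instructir

net = instructir.create_model(input_channels=3, width=32, enc_blks=[2, 2, 4, 8],
                              middle_blk_num=4, dec_blks=[2, 2, 2, 2], txtdim=256)
img = torch.rand(1, 3, 300, 200)   # arbitrary size; check_image_size() pads to a multiple of 16
txt = torch.rand(1, 256)           # stand-in for the LMHead instruction embedding
with torch.no_grad():
    out = net(img, txt)
print(out.shape)                   # torch.Size([1, 3, 300, 200]), cropped back to the input size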
insir_models/nafnet.py
ADDED
@@ -0,0 +1,201 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Source: https://github.com/megvii-research/NAFNet

'''
Simple Baselines for Image Restoration

@article{chen2022simple,
  title={Simple Baselines for Image Restoration},
  author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
  journal={arXiv preprint arXiv:2204.04676},
  year={2022}
}
'''

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm
from insir_models.nafnet_utils import Local_Base, LayerNorm2d


class SimpleGate(nn.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2

class NAFBlock(nn.Module):
    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
                               bias=True)
        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True),
        )

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        x = inp

        x = self.norm1(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)

        x = self.dropout1(x)

        y = inp + x * self.beta

        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)

        x = self.dropout2(x)

        return y + x * self.gamma


class NAFNet(nn.Module):

    def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[]):
        super().__init__()

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                               bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
                                bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )
            self.downs.append(
                nn.Conv2d(chan, 2*chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = \
            nn.Sequential(
                *[NAFBlock(chan) for _ in range(middle_blk_num)]
            )

        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )

        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp):
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)

        x = self.intro(inp)

        encs = []

        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        x = self.ending(x)
        x = x + inp

        return x[:, :, :H, :W]

    def check_image_size(self, x):
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x

class NAFNetLocal(Local_Base, NAFNet):
    def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs):
        Local_Base.__init__(self)
        NAFNet.__init__(self, *args, **kwargs)

        N, C, H, W = train_size
        base_size = (int(H * 1.5), int(W * 1.5))

        self.eval()
        with torch.no_grad():
            self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)


def create_nafnet(input_channels = 3, width = 32, enc_blks = [2, 2, 4, 8], middle_blk_num = 12, dec_blks = [2, 2, 2, 2]):
    """
    Create Nafnet model
    https://github.com/megvii-research/NAFNet/blob/main/options/test/SIDD/NAFNet-width32.yml
    """

    net = NAFNet(img_channel=input_channels, width=width, middle_blk_num=middle_blk_num,
                 enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)

    # inp_shape = (3, 256, 256)

    # from ptflops import get_model_complexity_info

    # macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False)

    # params = float(params[:-3])
    # macs = float(macs[:-4])

    # print(macs, params)

    return net
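For intuition about the two building blocks above, a short illustrative check (not part of the commit):

# Illustration only: SimpleGate halves the channel count, NAFBlock keeps the tensor shape.
import torch
from insir_models.nafnet import SimpleGate, NAFBlock

x = torch.rand(2, 64, 32, 32)
print(SimpleGate()(x).shape)   # torch.Size([2, 32, 32, 32]) -- x1 * x2 after chunking the channels
print(NAFBlock(64)(x).shape)   # torch.Size([2, 64, 32, 32]) -- residual block, shape preserved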
insir_models/nafnet_utils.py
ADDED
@@ -0,0 +1,146 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Source: https://github.com/megvii-research/NAFNet

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class LayerNormFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps

        N, C, H, W = grad_output.size()
        y, var, weight = ctx.saved_variables
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)

        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
            dim=0), None

class LayerNorm2d(nn.Module):

    def __init__(self, channels, eps=1e-6):
        super(LayerNorm2d, self).__init__()
        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
        self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)



class AvgPool2d(nn.Module):
    def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
        super().__init__()
        self.kernel_size = kernel_size
        self.base_size = base_size
        self.auto_pad = auto_pad

        # only used for fast implementation
        self.fast_imp = fast_imp
        self.rs = [5, 4, 3, 2, 1]
        self.max_r1 = self.rs[0]
        self.max_r2 = self.rs[0]
        self.train_size = train_size

    def extra_repr(self) -> str:
        return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
            self.kernel_size, self.base_size, self.kernel_size, self.fast_imp
        )

    def forward(self, x):
        if self.kernel_size is None and self.base_size:
            train_size = self.train_size
            if isinstance(self.base_size, int):
                self.base_size = (self.base_size, self.base_size)
            self.kernel_size = list(self.base_size)
            self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
            self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]

            # only used for fast implementation
            self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
            self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])

        if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
            return F.adaptive_avg_pool2d(x, 1)

        if self.fast_imp:  # Non-equivalent implementation but faster
            h, w = x.shape[2:]
            if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
                out = F.adaptive_avg_pool2d(x, 1)
            else:
                r1 = [r for r in self.rs if h % r == 0][0]
                r2 = [r for r in self.rs if w % r == 0][0]
                # reduction_constraint
                r1 = min(self.max_r1, r1)
                r2 = min(self.max_r2, r2)
                s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
                n, c, h, w = s.shape
                k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2)
                out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
                out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
        else:
            n, c, h, w = x.shape
            s = x.cumsum(dim=-1).cumsum_(dim=-2)
            s = torch.nn.functional.pad(s, (1, 0, 1, 0))  # pad 0 for convenience
            k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
            s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
            out = s4 + s1 - s2 - s3
            out = out / (k1 * k2)

        if self.auto_pad:
            n, c, h, w = x.shape
            _h, _w = out.shape[2:]
            # print(x.shape, self.kernel_size)
            pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
            out = torch.nn.functional.pad(out, pad2d, mode='replicate')

        return out

def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
    for n, m in model.named_children():
        if len(list(m.children())) > 0:
            ## compound module, go inside it
            replace_layers(m, base_size, train_size, fast_imp, **kwargs)

        if isinstance(m, nn.AdaptiveAvgPool2d):
            pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
            assert m.output_size == 1
            setattr(model, n, pool)


'''
ref.
@article{chu2021tlsc,
  title={Revisiting Global Statistics Aggregation for Improving Image Restoration},
  author={Chu, Xiaojie and Chen, Liangyu and and Chen, Chengpeng and Lu, Xin},
  journal={arXiv preprint arXiv:2112.04491},
  year={2021}
}
'''
class Local_Base():
    def convert(self, *args, train_size, **kwargs):
        replace_layers(self, *args, train_size=train_size, **kwargs)
        imgs = torch.rand(train_size)
        with torch.no_grad():
            self.forward(imgs)
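LayerNorm2d above normalizes each spatial position across the channel dimension via a custom autograd Function; a tiny illustrative check (not from the commit):

# Illustration only: at initialization (weight=1, bias=0) every pixel's channel vector becomes zero-mean.
import torch
from insir_models.nafnet_utils import LayerNorm2d

x = torch.rand(1, 16, 8, 8)
y = LayerNorm2d(16)(x)
print(y.shape)                    # torch.Size([1, 16, 8, 8])
print(y.mean(dim=1).abs().max())  # close to 0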
insir_text/.ipynb_checkpoints/models-checkpoint.py
ADDED
@@ -0,0 +1,65 @@
(Jupyter checkpoint copy; its 65 added lines are identical to insir_text/models.py below.)
insir_text/__pycache__/models.cpython-39.pyc
ADDED
Binary file (2.72 kB)
insir_text/models.py
ADDED
@@ -0,0 +1,65 @@
import torch
from torch import nn
import torch.nn.functional as F
from transformers import DistilBertModel, DistilBertTokenizer, AutoModel, AutoTokenizer
import os

# Models that use mean pooling
POOL_MODELS = {"sentence-transformers/all-MiniLM-L6-v2", "TaylorAI/bge-micro-v2"}

#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


class LanguageModel(nn.Module):
    def __init__(self, model='distilbert-base-uncased'):
        super(LanguageModel, self).__init__()

        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModel.from_pretrained(model)
        self.model_name = model
        # Remove the CLIP vision tower
        if "clip" in self.model_name:
            self.model.vision_model = None
        # Freeze the pre-trained parameters (very important)
        for param in self.model.parameters():
            param.requires_grad = False

        # Make sure to set evaluation mode (also important)
        self.model.eval()

    def forward(self, text_batch):
        inputs = self.tokenizer(text_batch, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad(): # Ensure no gradients are computed for this forward pass

            if "clip" in self.model_name:
                sentence_embedding = self.model.get_text_features(**inputs)
                return sentence_embedding

            outputs = self.model(**inputs)

        if any(model in self.model_name for model in POOL_MODELS):
            sentence_embeddings = mean_pooling(outputs, inputs['attention_mask'])
            # Normalize embeddings
            sentence_embedding = F.normalize(sentence_embeddings, p=2, dim=1)
        else:
            sentence_embedding = outputs.last_hidden_state[:, 0, :]
        return sentence_embedding


class LMHead(nn.Module):
    def __init__(self, embedding_dim=384, hidden_dim=256, num_classes=4):
        super(LMHead, self).__init__()

        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        #self.gelu = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        embd = self.fc1(x)
        embd = F.normalize(embd, p=2, dim=1)
        deg_pred = self.fc2(embd)
        return embd, deg_pred
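An illustrative sketch (not part of the commit) of how app.py uses these two classes to turn an instruction into the 256-d embedding that conditions InstructIR; the LMHead weights here are random unless lm_instructir-7d.pt is loaded as app.py does.

# Illustration only: instruction -> sentence embedding -> (text embedding, degradation logits).
import torch
from insir_text.models import LanguageModel, LMHead

lm = LanguageModel(model="TaylorAI/bge-micro-v2")              # frozen text encoder from eval5d.yml
head = LMHead(embedding_dim=384, hidden_dim=256, num_classes=7)

with torch.no_grad():
    emb = lm(["Remove the rain from my photo."])   # (1, 384) mean-pooled, L2-normalized
    text_embd, deg_pred = head(emb)                # (1, 256) and (1, 7)
print(text_embd.shape, deg_pred.shape)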
insir_text/sample_prompts.json
ADDED
@@ -0,0 +1,55 @@
{
    "denoising": [
        "Help me reduce the fuzziness in this image.",
        "I need this image denoised ASAP.",
        "Clean up this noisy image, it's an eyesore.",
        "Can you clean the dots from my image?",
        "Help me with my picture, it's full of tiny spots.",
        "Clean up this image, it's all grainy."
    ],
    "deblurring": [
        "Please, clean up this blurry photo.",
        "My picture's not sharp, fix it.",
        "Deblur my picture, it's too fuzzy.",
        "Help, my photo is too blurry.",
        "Please, make my image less smudgy."
    ],
    "dehazing": [
        "Please, fix the haziness in my image.",
        "I need to remove the haziness from this image.",
        "Get rid of the fog in my image.",
        "Fix my photo, it's too misty.",
        "Help me, my photo is all hazy."
    ],
    "deraining": [
        "I want to eliminate the water from this image.",
        "Clear the rain from my picture.",
        "I need to clear the rain from this image.",
        "Can you get rid of the raindrops in my picture?"
    ],
    "sr": [
        "I need to enhance the size and quality of this image.",
        "My photo is lacking size and clarity; can you improve it?",
        "I'd appreciate it if you could upscale this photo.",
        "My picture is too little, enlarge it."
    ],
    "ambiguous": [
        "Please, clear up the mess on this image.",
        "I want this image to look good.",
        "make it pop",
        "Fix my photo, it's all messed up."
    ],
    "lol": [
        "I took this photo during night, enhance it",
        "The photo is too dark, improve exposure",
        "my image has poor lighting conditions, can you fix it?",
        "Can you make the image brighter?"
    ],
    "enhancement": [
        "make my image look like DSLR",
        "improve the colors of my image",
        "enhance the colors of the image",
        "Can you edit this to look like an award-winning photo?",
        "I want the picture to be retouched for a professional portfolio."
    ]
}
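The prompt pools above are plain JSON, so sampling one instruction per degradation type is straightforward; an illustrative snippet (not part of the commit):

# Illustration only: pick a random denoising instruction from the file above.
import json
import random

with open("insir_text/sample_prompts.json", "r") as f:
    prompts = json.load(f)

print(random.choice(prompts["denoising"]))   # e.g. "Clean up this image, it's all grainy."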