xilluill committed
Commit 95d4bb7 · 1 Parent(s): 98a244a

inb version init

README.md CHANGED
@@ -1,14 +1,13 @@
  ---
  title: KV Edit
- emoji: 💻
+ emoji: 🐢
  colorFrom: gray
- colorTo: red
+ colorTo: pink
  sdk: gradio
- sdk_version: 5.17.0
+ sdk_version: 5.16.2
  app_file: app.py
  pinned: false
  license: apache-2.0
- short_description: inversion base version
  ---
 
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,7 +1,252 @@
+ import os
+ import re
+ import time
+ from dataclasses import dataclass
+ from glob import iglob
+ from einops import rearrange
+ from PIL import ExifTags, Image
+ import torch
  import gradio as gr
+ import numpy as np
+ from flux.sampling import prepare
+ from flux.util import (load_ae, load_clip, load_t5)
+ from models.kv_edit import Flux_kv_edit, Flux_kv_edit_inf
+ import spaces
+ from huggingface_hub import login
+ login(token=os.getenv('Token'))
 
- def greet(name):
-     return "Hello " + name + "!!"
+ @dataclass
+ class SamplingOptions:
+     source_prompt: str = ''
+     target_prompt: str = ''
+     # prompt: str
+     width: int = 1366
+     height: int = 768
+     inversion_num_steps: int = 0
+     denoise_num_steps: int = 0
+     skip_step: int = 0
+     inversion_guidance: float = 1.0
+     denoise_guidance: float = 1.0
+     seed: int = 42
+     re_init: bool = False
+     attn_mask: bool = False
 
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ @torch.inference_mode()
+ def encode(init_image, torch_device):
+     init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
+     init_image = init_image.unsqueeze(0)
+     init_image = init_image.to(torch_device)
+     with torch.no_grad():
+         init_image = ae.encode(init_image.to()).to(torch.bfloat16)
+     return init_image
+
+ # init all components
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ name = 'flux-dev'
+ ae = load_ae(name, device)
+ t5 = load_t5(device, max_length=256 if name == "flux-schnell" else 512)
+ clip = load_clip(device)
+ model = Flux_kv_edit(device=device, name=name)
+ offload = False
+ name = "flux-dev"
+ is_schnell = False
+ feature_path = 'feature'
+ output_dir = 'result'
+ add_sampling_metadata = True
+
+ @spaces.GPU(duration=120)
+ @torch.inference_mode()
+ def edit(init_image, brush_canvas,
+          source_prompt, target_prompt,
+          inversion_num_steps, denoise_num_steps,
+          skip_step,
+          inversion_guidance, denoise_guidance, seed,
+          re_init, attn_mask
+          ):
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     torch.cuda.empty_cache()
+
+     shape = init_image.shape
+     height = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16
+     width = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16
+
+     init_image = init_image[:height, :width, :]
+     brush_canvas = brush_canvas["composite"][:, :, :3][:height, :width, :]
+     # If the brush canvas is a three-channel grayscale image, it is already the input mask
+
+     if np.all(brush_canvas[:, :, 0] == brush_canvas[:, :, 1]) and np.all(brush_canvas[:, :, 1] == brush_canvas[:, :, 2]):
+         mask = brush_canvas[:, :, 0] / 255
+         mask = mask.astype(int)
+     else:
+         mask = np.any(init_image != brush_canvas, axis=-1)  # gives a 2-D boolean array
+         mask = mask.astype(int)
+     mask_array = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
+     mask_array[:, :, 0] = mask * 255  # R
+     mask_array[:, :, 3] = mask * 128  # A (semi-transparent; 128 means roughly 50% opacity)
+     mask_image = Image.fromarray(mask_array, 'RGBA')
+     original_image = Image.fromarray(np.concatenate((init_image, np.full((height, width, 1), 255, dtype=np.uint8)), axis=2), 'RGBA')
+     masked_image = Image.alpha_composite(original_image, mask_image)
+     mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(torch.bfloat16).to(device)
+
+     init_image = encode(init_image, device).to(device)
+
+     seed = int(seed)
+     if seed == -1:
+         seed = torch.randint(0, 2**32, (1,)).item()
+     opts = SamplingOptions(
+         source_prompt=source_prompt,
+         target_prompt=target_prompt,
+         width=width,
+         height=height,
+         inversion_num_steps=inversion_num_steps,
+         denoise_num_steps=denoise_num_steps,
+         skip_step=skip_step,
+         inversion_guidance=inversion_guidance,
+         denoise_guidance=denoise_guidance,
+         seed=seed,
+         re_init=re_init,
+         attn_mask=attn_mask
+     )
+
+
+     torch.manual_seed(opts.seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(opts.seed)
+
+     t0 = time.perf_counter()
+
+     ############# inversion #######################
+     # convert the boolean array to integers if 1/0 is needed instead of True/False
+     with torch.no_grad():
+         inp = prepare(t5, clip, init_image, prompt=opts.source_prompt)
+         inp_target = prepare(t5, clip, init_image, prompt=opts.target_prompt)
+
+     x = model(inp, inp_target, mask, opts)
+
+     device = torch.device("cuda")
+     with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
+         x = ae.decode(x)
+     # the decoded result is still on the GPU
+     # bring into PIL format and save
+     x = x.clamp(-1, 1)
+     # x = embed_watermark(x.float())
+     x = x.float().cpu()
+     x = rearrange(x[0], "c h w -> h w c")
+
+     if torch.cuda.is_available():
+         torch.cuda.synchronize()
+     ############# back in pixel space, editing is finished #######################
+     output_name = os.path.join(output_dir, "img_{idx}.jpg")
+     if not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+         idx = 0
+     else:
+         fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
+         if len(fns) > 0:
+             idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
+         else:
+             idx = 0
+     ############# find the next output index #######################
+
+     fn = output_name.format(idx=idx)
+
+     img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
+     exif_data = Image.Exif()
+     exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
+     exif_data[ExifTags.Base.Make] = "Black Forest Labs"
+     exif_data[ExifTags.Base.Model] = name
+
+     exif_data[ExifTags.Base.ImageDescription] = source_prompt
+     img.save(fn, exif=exif_data, quality=95, subsampling=0)
+     masked_image.save(fn.replace(".jpg", "_mask.png"), format='PNG')
+     t1 = time.perf_counter()
+     print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
+
+     print("End Edit")
+     return img
+
+
+
+ def create_demo(model_name: str):
+     # editor = FluxEditor_kv_demo()
+     is_schnell = model_name == "flux-schnell"
+
+     title = r"""
+ <h1 align="center">🎨 KV-Edit: Training-Free Image Editing for Precise Background Preservation</h1>
+ """
+
+     description = r"""
+ <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/Xilluill/KV-Edit' target='_blank'><b>KV-Edit: Training-Free Image Editing for Precise Background Preservation</b></a>.<br>
+
+ 🔔🔔[<b>Important</b>] Editing steps:<br>
+ 1️⃣ Upload the image you want to edit (keep the resolution below 1360*768, otherwise the GPU may run out of memory). <br>
+ 2️⃣ Re-upload the original image and use the brush tool to draw the mask area. <br>
+ 3️⃣ Fill in your source prompt and target prompt, then adjust the hyperparameters. <br>
+ 4️⃣ Click the "Edit" button to generate your edited image! <br>
+ """
+     article = r"""
+ If our work is helpful, please help to ⭐ the <a href='https://github.com/Xilluill/KV-Edit' target='_blank'>Github Repo</a>. Thanks!
+ """
+
+     badge = r"""
+ [![GitHub Stars](https://img.shields.io/github/stars/Xilluill/KV-Edit)](https://github.com/Xilluill/KV-Edit)
+ """
+
+     with gr.Blocks() as demo:
+         gr.HTML(title)
+         gr.Markdown(description)
+         gr.Markdown(article)
+         # gr.Markdown(badge)
+
+         with gr.Row():
+             with gr.Column():
+                 source_prompt = gr.Textbox(label="Source Prompt", value='In a cluttered wooden cabin, a workbench holds a green neon sign that reads "I love here"')
+                 inversion_num_steps = gr.Slider(1, 50, 28, step=1, label="Number of inversion steps")
+                 target_prompt = gr.Textbox(label="Target Prompt", value='In a cluttered wooden cabin, a workbench holds a green neon sign that reads "I love iccv"')
+                 denoise_num_steps = gr.Slider(1, 50, 28, step=1, label="Number of denoise steps")
+                 init_image = gr.Image(label="Input Image", visible=True)
+                 brush_canvas = gr.ImageEditor(label="Brush Canvas",
+                                               sources=('upload'),
+                                               brush=gr.Brush(default_size=10,
+                                                              default_color="#000000"),
+                                               interactive=True,
+                                               container=True,
+                                               transforms=[],
+                                               height="auto",
+                                               format='png', scale=1)
+
+                 edit_btn = gr.Button("Edit")
+
+
+             with gr.Column():
+                 with gr.Accordion("Advanced Options", open=True):
+                     # num_steps = gr.Slider(1, 30, 25, step=1, label="Number of steps")
+
+                     skip_step = gr.Slider(0, 30, 4, step=1, label="Number of inject steps")
+                     inversion_guidance = gr.Slider(1.0, 10.0, 1.5, step=0.1, label="inversion Guidance", interactive=not is_schnell)
+                     denoise_guidance = gr.Slider(1.0, 10.0, 5.5, step=0.1, label="denoise Guidance", interactive=not is_schnell)
+                     seed = gr.Textbox('0', label="Seed (-1 for random)", visible=True)
+                     with gr.Row():
+                         re_init = gr.Checkbox(label="re_init", value=False)
+                         attn_mask = gr.Checkbox(label="attn_mask", value=False)
+
+
+                 output_image = gr.Image(label="Generated Image")
+         edit_btn.click(
+             fn=edit,
+             inputs=[init_image, brush_canvas,
+                     source_prompt, target_prompt,
+                     inversion_num_steps, denoise_num_steps,
+                     skip_step,
+                     inversion_guidance,
+                     denoise_guidance, seed,
+                     re_init, attn_mask
+                     ],
+             outputs=[output_image]
+         )
+     return demo
+
+
+ demo = create_demo("flux-dev")
+
+ demo.launch()
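
For readers of the diff above: `edit()` accepts the brush canvas in two forms. If the composite is a pure grayscale image it is taken directly as the mask; otherwise the mask is every pixel where the painted canvas differs from the original image. A minimal, self-contained sketch of that branch (the helper name `canvas_to_mask` is hypothetical and not part of app.py):

```python
import numpy as np

def canvas_to_mask(init_image: np.ndarray, composite: np.ndarray) -> np.ndarray:
    """Minimal sketch of the mask logic in edit(); hypothetical helper, not part of app.py."""
    rgb = composite[:, :, :3]
    # If the canvas is grayscale (R == G == B everywhere), treat it directly as the mask.
    if np.all(rgb[:, :, 0] == rgb[:, :, 1]) and np.all(rgb[:, :, 1] == rgb[:, :, 2]):
        return (rgb[:, :, 0] / 255).astype(int)
    # Otherwise every pixel the brush touched differs from the original image.
    return np.any(init_image != rgb, axis=-1).astype(int)
```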
flux/__init__.py ADDED
@@ -0,0 +1,11 @@
+ try:
+     from ._version import version as __version__  # type: ignore
+     from ._version import version_tuple
+ except ImportError:
+     __version__ = "unknown (no version information available)"
+     version_tuple = (0, 0, "unknown", "noinfo")
+
+ from pathlib import Path
+
+ PACKAGE = __package__.replace("_", "-")
+ PACKAGE_ROOT = Path(__file__).parent
flux/__main__.py ADDED
@@ -0,0 +1,4 @@
+ from .cli import app
+
+ if __name__ == "__main__":
+     app()
flux/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (477 Bytes)
flux/__pycache__/_version.cpython-310.pyc ADDED
Binary file (485 Bytes)
flux/__pycache__/math.cpython-310.pyc ADDED
Binary file (2.23 kB)
flux/__pycache__/model.cpython-310.pyc ADDED
Binary file (4.81 kB)
flux/__pycache__/sampling.cpython-310.pyc ADDED
Binary file (6.72 kB)
flux/__pycache__/util.cpython-310.pyc ADDED
Binary file (5.53 kB)
flux/_version.py ADDED
@@ -0,0 +1,16 @@
+ # file generated by setuptools_scm
+ # don't change, don't track in version control
+ TYPE_CHECKING = False
+ if TYPE_CHECKING:
+     from typing import Tuple, Union
+     VERSION_TUPLE = Tuple[Union[int, str], ...]
+ else:
+     VERSION_TUPLE = object
+
+ version: str
+ __version__: str
+ __version_tuple__: VERSION_TUPLE
+ version_tuple: VERSION_TUPLE
+
+ __version__ = version = '0.0.post6+ge52e00f.d20250111'
+ __version_tuple__ = version_tuple = (0, 0, 'ge52e00f.d20250111')
flux/api.py ADDED
@@ -0,0 +1,194 @@
1
+ import io
2
+ import os
3
+ import time
4
+ from pathlib import Path
5
+
6
+ import requests
7
+ from PIL import Image
8
+
9
+ API_ENDPOINT = "https://api.bfl.ml"
10
+
11
+
12
+ class ApiException(Exception):
13
+ def __init__(self, status_code: int, detail: str | list[dict] | None = None):
14
+ super().__init__()
15
+ self.detail = detail
16
+ self.status_code = status_code
17
+
18
+ def __str__(self) -> str:
19
+ return self.__repr__()
20
+
21
+ def __repr__(self) -> str:
22
+ if self.detail is None:
23
+ message = None
24
+ elif isinstance(self.detail, str):
25
+ message = self.detail
26
+ else:
27
+ message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
28
+ return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"
29
+
30
+
31
+ class ImageRequest:
32
+ def __init__(
33
+ self,
34
+ prompt: str,
35
+ width: int = 1024,
36
+ height: int = 1024,
37
+ name: str = "flux.1-pro",
38
+ num_steps: int = 50,
39
+ prompt_upsampling: bool = False,
40
+ seed: int | None = None,
41
+ validate: bool = True,
42
+ launch: bool = True,
43
+ api_key: str | None = None,
44
+ ):
45
+ """
46
+ Manages an image generation request to the API.
47
+
48
+ Args:
49
+ prompt: Prompt to sample
50
+ width: Width of the image in pixel
51
+ height: Height of the image in pixel
52
+ name: Name of the model
53
+ num_steps: Number of network evaluations
54
+ prompt_upsampling: Use prompt upsampling
55
+ seed: Fix the generation seed
56
+ validate: Run input validation
57
+ launch: Directly launches request
58
+ api_key: Your API key if not provided by the environment
59
+
60
+ Raises:
61
+ ValueError: For invalid input
62
+ ApiException: For errors raised from the API
63
+ """
64
+ if validate:
65
+ if name not in ["flux.1-pro"]:
66
+ raise ValueError(f"Invalid model {name}")
67
+ elif width % 32 != 0:
68
+ raise ValueError(f"width must be divisible by 32, got {width}")
69
+ elif not (256 <= width <= 1440):
70
+ raise ValueError(f"width must be between 256 and 1440, got {width}")
71
+ elif height % 32 != 0:
72
+ raise ValueError(f"height must be divisible by 32, got {height}")
73
+ elif not (256 <= height <= 1440):
74
+ raise ValueError(f"height must be between 256 and 1440, got {height}")
75
+ elif not (1 <= num_steps <= 50):
76
+ raise ValueError(f"steps must be between 1 and 50, got {num_steps}")
77
+
78
+ self.request_json = {
79
+ "prompt": prompt,
80
+ "width": width,
81
+ "height": height,
82
+ "variant": name,
83
+ "steps": num_steps,
84
+ "prompt_upsampling": prompt_upsampling,
85
+ }
86
+ if seed is not None:
87
+ self.request_json["seed"] = seed
88
+
89
+ self.request_id: str | None = None
90
+ self.result: dict | None = None
91
+ self._image_bytes: bytes | None = None
92
+ self._url: str | None = None
93
+ if api_key is None:
94
+ self.api_key = os.environ.get("BFL_API_KEY")
95
+ else:
96
+ self.api_key = api_key
97
+
98
+ if launch:
99
+ self.request()
100
+
101
+ def request(self):
102
+ """
103
+ Request to generate the image.
104
+ """
105
+ if self.request_id is not None:
106
+ return
107
+ response = requests.post(
108
+ f"{API_ENDPOINT}/v1/image",
109
+ headers={
110
+ "accept": "application/json",
111
+ "x-key": self.api_key,
112
+ "Content-Type": "application/json",
113
+ },
114
+ json=self.request_json,
115
+ )
116
+ result = response.json()
117
+ if response.status_code != 200:
118
+ raise ApiException(status_code=response.status_code, detail=result.get("detail"))
119
+ self.request_id = response.json()["id"]
120
+
121
+ def retrieve(self) -> dict:
122
+ """
123
+ Wait for the generation to finish and retrieve response.
124
+ """
125
+ if self.request_id is None:
126
+ self.request()
127
+ while self.result is None:
128
+ response = requests.get(
129
+ f"{API_ENDPOINT}/v1/get_result",
130
+ headers={
131
+ "accept": "application/json",
132
+ "x-key": self.api_key,
133
+ },
134
+ params={
135
+ "id": self.request_id,
136
+ },
137
+ )
138
+ result = response.json()
139
+ if "status" not in result:
140
+ raise ApiException(status_code=response.status_code, detail=result.get("detail"))
141
+ elif result["status"] == "Ready":
142
+ self.result = result["result"]
143
+ elif result["status"] == "Pending":
144
+ time.sleep(0.5)
145
+ else:
146
+ raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
147
+ return self.result
148
+
149
+ @property
150
+ def bytes(self) -> bytes:
151
+ """
152
+ Generated image as bytes.
153
+ """
154
+ if self._image_bytes is None:
155
+ response = requests.get(self.url)
156
+ if response.status_code == 200:
157
+ self._image_bytes = response.content
158
+ else:
159
+ raise ApiException(status_code=response.status_code)
160
+ return self._image_bytes
161
+
162
+ @property
163
+ def url(self) -> str:
164
+ """
165
+ Public url to retrieve the image from
166
+ """
167
+ if self._url is None:
168
+ result = self.retrieve()
169
+ self._url = result["sample"]
170
+ return self._url
171
+
172
+ @property
173
+ def image(self) -> Image.Image:
174
+ """
175
+ Load the image as a PIL Image
176
+ """
177
+ return Image.open(io.BytesIO(self.bytes))
178
+
179
+ def save(self, path: str):
180
+ """
181
+ Save the generated image to a local path
182
+ """
183
+ suffix = Path(self.url).suffix
184
+ if not path.endswith(suffix):
185
+ path = path + suffix
186
+ Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
187
+ with open(path, "wb") as file:
188
+ file.write(self.bytes)
189
+
190
+
191
+ if __name__ == "__main__":
192
+ from fire import Fire
193
+
194
+ Fire(ImageRequest)
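
Taken together, the `ImageRequest` class above launches the generation on construction (when `launch=True`), `retrieve()` polls `/v1/get_result` until the status is `Ready`, and `bytes`/`image`/`save()` download the resulting sample. A hedged usage sketch, assuming `BFL_API_KEY` is set in the environment and the API endpoint is reachable:

```python
# Usage sketch for flux.api.ImageRequest (assumes BFL_API_KEY is exported and the API is reachable).
from flux.api import ImageRequest

request = ImageRequest(prompt="a mountain lake at dawn", width=1024, height=768, num_steps=40)
request.retrieve()        # blocks, polling every 0.5 s until the result is "Ready"
print(request.url)        # public URL of the generated sample
request.save("lake.jpg")  # downloads the bytes and appends the correct suffix if missing
```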
flux/math.py ADDED
@@ -0,0 +1,40 @@
+ import torch
+ from einops import rearrange
+ from torch import Tensor
+
+
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, pe_q=None, attention_mask=None) -> Tensor:
+     if pe_q is None:
+         q, k = apply_rope(q, k, pe)  # q/k: torch.Size([1, 24, 4592, 128]), pe: torch.Size([1, 1, 4592, 64, 2, 2])
+         x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)  # torch.Size([1, 24, 4592, 128])
+         x = rearrange(x, "B H L D -> B L (H D)")  # torch.Size([1, 4592, 3072])
+         return x
+     else:
+         q, k = apply_rope_qk(q, k, pe_q, pe)
+         x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
+         x = rearrange(x, "B H L D -> B L (H D)")
+         return x
+
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
+     assert dim % 2 == 0
+     scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim  # dim = 16 + 56 + 56
+     omega = 1.0 / (theta**scale)  # 64 omegas
+     out = torch.einsum("...n,d->...nd", pos, omega)
+     out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
+     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)  # torch.Size([1, 1, 4592, x, 2, 2]), x = 8 + 28 + 28
+     return out.float()
+
+
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
+     xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)  # torch.Size([1, 24, 4592, 128]) -> torch.Size([1, 24, 4592, 64, 1, 2])
+     xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)  # torch.Size([1, 24, 4592, 128]) -> torch.Size([1, 24, 4592, 64, 1, 2])
+     xq_out = freqs_cis[:, :, :xq_.shape[2], :, :, 0] * xq_[..., 0] + freqs_cis[:, :, :xq_.shape[2], :, :, 1] * xq_[..., 1]
+     xk_out = freqs_cis[:, :, :xk_.shape[2], :, :, 0] * xk_[..., 0] + freqs_cis[:, :, :xk_.shape[2], :, :, 1] * xk_[..., 1]
+     return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+
+ def apply_rope_qk(xq: Tensor, xk: Tensor, freqs_cis_q: Tensor, freqs_cis_k: Tensor) -> tuple[Tensor, Tensor]:
+     xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+     xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+     xq_out = freqs_cis_q[:, :, :xq_.shape[2], :, :, 0] * xq_[..., 0] + freqs_cis_q[:, :, :xq_.shape[2], :, :, 1] * xq_[..., 1]
+     xk_out = freqs_cis_k[:, :, :xk_.shape[2], :, :, 0] * xk_[..., 0] + freqs_cis_k[:, :, :xk_.shape[2], :, :, 1] * xk_[..., 1]
+     return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
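
As a quick aid to reading `rope`/`apply_rope` above: `rope` builds one 2x2 rotation matrix `[[cos, -sin], [sin, cos]]` per position and frequency, and `apply_rope` applies it to consecutive channel pairs of the queries and keys, so each token vector is only rotated. A small hedged sanity-check sketch (the shapes here are arbitrary and purely illustrative):

```python
# Hedged sanity check for rope()/apply_rope(); shapes are illustrative only.
import torch
from flux.math import rope, apply_rope

pos = torch.arange(8, dtype=torch.float64)[None, :]  # (batch=1, 8 positions)
pe = rope(pos, 16, 10_000).unsqueeze(1)              # (1, 1, 8, 8, 2, 2): one rotation per channel pair
q = torch.randn(1, 2, 8, 16)                         # (B, heads, L, D) with D = 16
k = torch.randn(1, 2, 8, 16)
q_rot, k_rot = apply_rope(q, k, pe)
# Rotations preserve the norm of every token vector.
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)
```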
flux/model.py ADDED
@@ -0,0 +1,170 @@
+ from dataclasses import dataclass
+
+ import torch
+ from torch import Tensor, nn
+
+ from flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
+                                  MLPEmbedder, SingleStreamBlock, DoubleStreamBlock_rf, SingleStreamBlock_rf,
+                                  SingleStreamBlock_kv, DoubleStreamBlock_kv,
+                                  timestep_embedding)
+
+
+ @dataclass
+ class FluxParams:
+     in_channels: int
+     vec_in_dim: int
+     context_in_dim: int
+     hidden_size: int
+     mlp_ratio: float
+     num_heads: int
+     depth: int
+     depth_single_blocks: int
+     axes_dim: list[int]
+     theta: int
+     qkv_bias: bool
+     guidance_embed: bool
+
+
+ class Flux(nn.Module):
+     """
+     Transformer model for flow matching on sequences.
+     """
+
+     def __init__(self, params: FluxParams, double_block_cls=DoubleStreamBlock, single_block_cls=SingleStreamBlock):
+         super().__init__()
+
+         self.params = params
+         self.in_channels = params.in_channels
+         self.out_channels = self.in_channels
+         if params.hidden_size % params.num_heads != 0:
+             raise ValueError(
+                 f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+             )
+         pe_dim = params.hidden_size // params.num_heads
+         if sum(params.axes_dim) != pe_dim:
+             raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
+         self.hidden_size = params.hidden_size
+         self.num_heads = params.num_heads
+         self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
+         self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
+         self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+         self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
+         self.guidance_in = (
+             MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
+         )
+         self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
+
+         self.double_blocks = nn.ModuleList(
+             [
+                 double_block_cls(
+                     self.hidden_size,
+                     self.num_heads,
+                     mlp_ratio=params.mlp_ratio,
+                     qkv_bias=params.qkv_bias,
+                 )
+                 for _ in range(params.depth)
+             ]
+         )
+
+         self.single_blocks = nn.ModuleList(
+             [
+                 single_block_cls(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
+                 for _ in range(params.depth_single_blocks)
+             ]
+         )
+
+         self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
+
+     def forward(
+         self,
+         img: Tensor,
+         img_ids: Tensor,
+         txt: Tensor,
+         txt_ids: Tensor,
+         timesteps: Tensor,
+         y: Tensor,
+         guidance: Tensor | None = None,
+     ) -> Tensor:
+         if img.ndim != 3 or txt.ndim != 3:
+             raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+         # running on sequences img
+         img = self.img_in(img)
+         vec = self.time_in(timestep_embedding(timesteps, 256))
+         if self.params.guidance_embed:
+             if guidance is None:
+                 raise ValueError("Didn't get guidance strength for guidance distilled model.")
+             vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+         vec = vec + self.vector_in(y)
+         txt = self.txt_in(txt)
+
+         ids = torch.cat((txt_ids, img_ids), dim=1)
+         pe = self.pe_embedder(ids)
+
+         for block in self.double_blocks:
+             img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+
+         img = torch.cat((txt, img), 1)
+         for block in self.single_blocks:
+             img = block(img, vec=vec, pe=pe)
+         img = img[:, txt.shape[1] :, ...]
+
+         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
+         return img
+
+ class Flux_kv(Flux):
+     """
+     Subclass of Flux that overrides forward to thread a KV-cache info dict through every block.
+     """
+
+     def __init__(self, params: FluxParams, double_block_cls=DoubleStreamBlock_kv, single_block_cls=SingleStreamBlock_kv):
+         super().__init__(params, double_block_cls, single_block_cls)
+
+     def forward(
+         self,
+         img: Tensor,          # (B, L, D), e.g. (1, 4080, 64)
+         img_ids: Tensor,
+         txt: Tensor,          # torch.Size([1, 512, 4096])
+         txt_ids: Tensor,
+         timesteps: Tensor,    # torch.Size([1])
+         y: Tensor,            # torch.Size([1, 768])
+         guidance: Tensor | None = None,  # torch.Size([1])
+         info: dict = {},
+     ) -> Tensor:
+         if img.ndim != 3 or txt.ndim != 3:
+             raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+         # running on sequences img
+         img = self.img_in(img)
+         vec = self.time_in(timestep_embedding(timesteps, 256))  # torch.Size([1, 3072])
+         if self.params.guidance_embed:
+             if guidance is None:
+                 raise ValueError("Didn't get guidance strength for guidance distilled model.")
+             vec = vec + self.guidance_in(timestep_embedding(guidance, 256))  # torch.Size([1, 3072])
+         vec = vec + self.vector_in(y)  # torch.Size([1, 3072])
+         txt = self.txt_in(txt)  # ([1, 512, 4096]) -> torch.Size([1, 512, 3072])
+
+         ids = torch.cat((txt_ids, img_ids), dim=1)  # torch.Size([1, 512, 3]) + torch.Size([1, 4080, 3]) -> torch.Size([1, 4592, 3])
+         pe = self.pe_embedder(ids)  # torch.Size([1, 1, 4592, 64, 2, 2])
+         if not info['inverse']:
+             mask_indices = info['mask_indices']  # masked token indices in image-sequence coordinates
+             info['pe_mask'] = torch.cat((pe[:, :, :512, ...], pe[:, :, mask_indices + 512, ...]), dim=2)
+
+         cnt = 0
+         for block in self.double_blocks:
+             info['id'] = cnt
+             img, txt = block(img=img, txt=txt, vec=vec, pe=pe, info=info)
+             cnt += 1
+
+         cnt = 0
+         x = torch.cat((txt, img), 1)
+         for block in self.single_blocks:
+             info['id'] = cnt
+             x = block(x, vec=vec, pe=pe, info=info)
+             cnt += 1
+
+         img = x[:, txt.shape[1] :, ...]
+
+         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
+
+         return img
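
`Flux_kv.forward` does not own the KV cache itself; it threads a mutable `info` dict into every double- and single-stream block and only adds `info['id']` (the block index) and, on the denoising pass, `info['pe_mask']`. Based on the code above, a hedged sketch of what the caller (the `Flux_kv_edit` wrapper imported in app.py, which is not part of this commit) is expected to put into that dict; the concrete values are illustrative:

```python
# Hedged sketch of the `info` dict consumed by Flux_kv.forward and the *_kv blocks;
# field names come from the code above, the values are illustrative only.
info = {
    'inverse': True,       # True while inverting the source image, False while denoising
    'feature': {},         # cache keyed like f"{t}_{block_id}_MB_K" / "_MB_V" / "_SB_K" / "_SB_V"
    't': 0.5,              # current timestep, used only to build the cache keys
    'mask_indices': None,  # 1-D LongTensor of masked token positions in image-sequence coordinates
    # optional: 'attention_mask' restricts attention during inversion;
    # 'pe_mask' is filled in by Flux_kv.forward itself on the denoising pass.
}
```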
flux/modules/__pycache__/autoencoder.cpython-310.pyc ADDED
Binary file (9.03 kB)
flux/modules/__pycache__/conditioner.cpython-310.pyc ADDED
Binary file (1.47 kB)
flux/modules/__pycache__/layers.cpython-310.pyc ADDED
Binary file (15.9 kB)
flux/modules/autoencoder.py ADDED
@@ -0,0 +1,313 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from torch import Tensor, nn
6
+
7
+
8
+ @dataclass
9
+ class AutoEncoderParams:
10
+ resolution: int
11
+ in_channels: int
12
+ ch: int
13
+ out_ch: int
14
+ ch_mult: list[int]
15
+ num_res_blocks: int
16
+ z_channels: int
17
+ scale_factor: float
18
+ shift_factor: float
19
+
20
+
21
+ def swish(x: Tensor) -> Tensor:
22
+ return x * torch.sigmoid(x)
23
+
24
+
25
+ class AttnBlock(nn.Module):
26
+ def __init__(self, in_channels: int):
27
+ super().__init__()
28
+ self.in_channels = in_channels
29
+
30
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
31
+
32
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
33
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
34
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
35
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
36
+
37
+ def attention(self, h_: Tensor) -> Tensor:
38
+ h_ = self.norm(h_)
39
+ q = self.q(h_)
40
+ k = self.k(h_)
41
+ v = self.v(h_)
42
+
43
+ b, c, h, w = q.shape
44
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
45
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
46
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
47
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
48
+
49
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
50
+
51
+ def forward(self, x: Tensor) -> Tensor:
52
+ return x + self.proj_out(self.attention(x))
53
+
54
+
55
+ class ResnetBlock(nn.Module):
56
+ def __init__(self, in_channels: int, out_channels: int):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+ out_channels = in_channels if out_channels is None else out_channels
60
+ self.out_channels = out_channels
61
+
62
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
63
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
64
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
65
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
66
+ if self.in_channels != self.out_channels:
67
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
68
+
69
+ def forward(self, x):
70
+ h = x
71
+ h = self.norm1(h)
72
+ h = swish(h)
73
+ h = self.conv1(h)
74
+
75
+ h = self.norm2(h)
76
+ h = swish(h)
77
+ h = self.conv2(h)
78
+
79
+ if self.in_channels != self.out_channels:
80
+ x = self.nin_shortcut(x)
81
+
82
+ return x + h
83
+
84
+
85
+ class Downsample(nn.Module):
86
+ def __init__(self, in_channels: int):
87
+ super().__init__()
88
+ # no asymmetric padding in torch conv, must do it ourselves
89
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
90
+
91
+ def forward(self, x: Tensor):
92
+ pad = (0, 1, 0, 1)
93
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
94
+ x = self.conv(x)
95
+ return x
96
+
97
+
98
+ class Upsample(nn.Module):
99
+ def __init__(self, in_channels: int):
100
+ super().__init__()
101
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
102
+
103
+ def forward(self, x: Tensor):
104
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
105
+ x = self.conv(x)
106
+ return x
107
+
108
+
109
+ class Encoder(nn.Module):
110
+ def __init__(
111
+ self,
112
+ resolution: int,
113
+ in_channels: int,
114
+ ch: int,
115
+ ch_mult: list[int],
116
+ num_res_blocks: int,
117
+ z_channels: int,
118
+ ):
119
+ super().__init__()
120
+ self.ch = ch
121
+ self.num_resolutions = len(ch_mult)
122
+ self.num_res_blocks = num_res_blocks
123
+ self.resolution = resolution
124
+ self.in_channels = in_channels
125
+ # downsampling
126
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
127
+
128
+ curr_res = resolution
129
+ in_ch_mult = (1,) + tuple(ch_mult)
130
+ self.in_ch_mult = in_ch_mult
131
+ self.down = nn.ModuleList()
132
+ block_in = self.ch
133
+ for i_level in range(self.num_resolutions):
134
+ block = nn.ModuleList()
135
+ attn = nn.ModuleList()
136
+ block_in = ch * in_ch_mult[i_level]
137
+ block_out = ch * ch_mult[i_level]
138
+ for _ in range(self.num_res_blocks):
139
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
140
+ block_in = block_out
141
+ down = nn.Module()
142
+ down.block = block
143
+ down.attn = attn
144
+ if i_level != self.num_resolutions - 1:
145
+ down.downsample = Downsample(block_in)
146
+ curr_res = curr_res // 2
147
+ self.down.append(down)
148
+
149
+ # middle
150
+ self.mid = nn.Module()
151
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
152
+ self.mid.attn_1 = AttnBlock(block_in)
153
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
154
+
155
+ # end
156
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
157
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
158
+
159
+ def forward(self, x: Tensor) -> Tensor:
160
+ # downsampling
161
+ hs = [self.conv_in(x)]
162
+ for i_level in range(self.num_resolutions):
163
+ for i_block in range(self.num_res_blocks):
164
+ h = self.down[i_level].block[i_block](hs[-1])
165
+ if len(self.down[i_level].attn) > 0:
166
+ h = self.down[i_level].attn[i_block](h)
167
+ hs.append(h)
168
+ if i_level != self.num_resolutions - 1:
169
+ hs.append(self.down[i_level].downsample(hs[-1]))
170
+
171
+ # middle
172
+ h = hs[-1]
173
+ h = self.mid.block_1(h)
174
+ h = self.mid.attn_1(h)
175
+ h = self.mid.block_2(h)
176
+ # end
177
+ h = self.norm_out(h)
178
+ h = swish(h)
179
+ h = self.conv_out(h)
180
+ return h
181
+
182
+
183
+ class Decoder(nn.Module):
184
+ def __init__(
185
+ self,
186
+ ch: int,
187
+ out_ch: int,
188
+ ch_mult: list[int],
189
+ num_res_blocks: int,
190
+ in_channels: int,
191
+ resolution: int,
192
+ z_channels: int,
193
+ ):
194
+ super().__init__()
195
+ self.ch = ch
196
+ self.num_resolutions = len(ch_mult)
197
+ self.num_res_blocks = num_res_blocks
198
+ self.resolution = resolution
199
+ self.in_channels = in_channels
200
+ self.ffactor = 2 ** (self.num_resolutions - 1)
201
+
202
+ # compute in_ch_mult, block_in and curr_res at lowest res
203
+ block_in = ch * ch_mult[self.num_resolutions - 1]
204
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
205
+ self.z_shape = (1, z_channels, curr_res, curr_res)
206
+
207
+ # z to block_in
208
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
209
+
210
+ # middle
211
+ self.mid = nn.Module()
212
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
213
+ self.mid.attn_1 = AttnBlock(block_in)
214
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
215
+
216
+ # upsampling
217
+ self.up = nn.ModuleList()
218
+ for i_level in reversed(range(self.num_resolutions)):
219
+ block = nn.ModuleList()
220
+ attn = nn.ModuleList()
221
+ block_out = ch * ch_mult[i_level]
222
+ for _ in range(self.num_res_blocks + 1):
223
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
224
+ block_in = block_out
225
+ up = nn.Module()
226
+ up.block = block
227
+ up.attn = attn
228
+ if i_level != 0:
229
+ up.upsample = Upsample(block_in)
230
+ curr_res = curr_res * 2
231
+ self.up.insert(0, up) # prepend to get consistent order
232
+
233
+ # end
234
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
235
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
236
+
237
+ def forward(self, z: Tensor) -> Tensor:
238
+ # z to block_in
239
+ h = self.conv_in(z)
240
+
241
+ # middle
242
+ h = self.mid.block_1(h)
243
+ h = self.mid.attn_1(h)
244
+ h = self.mid.block_2(h)
245
+
246
+ # upsampling
247
+ for i_level in reversed(range(self.num_resolutions)):
248
+ for i_block in range(self.num_res_blocks + 1):
249
+ h = self.up[i_level].block[i_block](h)
250
+ if len(self.up[i_level].attn) > 0:
251
+ h = self.up[i_level].attn[i_block](h)
252
+ if i_level != 0:
253
+ h = self.up[i_level].upsample(h)
254
+
255
+ # end
256
+ h = self.norm_out(h)
257
+ h = swish(h)
258
+ h = self.conv_out(h)
259
+ return h
260
+
261
+
262
+ class DiagonalGaussian(nn.Module):
263
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
264
+ super().__init__()
265
+ self.sample = sample
266
+ self.chunk_dim = chunk_dim
267
+
268
+ def forward(self, z: Tensor) -> Tensor:
269
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
270
+ # import pdb;pdb.set_trace()
271
+ if self.sample:
272
+ std = torch.exp(0.5 * logvar)
273
+ return mean #+ std * torch.randn_like(mean)
274
+ else:
275
+ return mean
276
+
277
+
278
+ class AutoEncoder(nn.Module):
279
+ def __init__(self, params: AutoEncoderParams):
280
+ super().__init__()
281
+ self.encoder = Encoder(
282
+ resolution=params.resolution,
283
+ in_channels=params.in_channels,
284
+ ch=params.ch,
285
+ ch_mult=params.ch_mult,
286
+ num_res_blocks=params.num_res_blocks,
287
+ z_channels=params.z_channels,
288
+ )
289
+ self.decoder = Decoder(
290
+ resolution=params.resolution,
291
+ in_channels=params.in_channels,
292
+ ch=params.ch,
293
+ out_ch=params.out_ch,
294
+ ch_mult=params.ch_mult,
295
+ num_res_blocks=params.num_res_blocks,
296
+ z_channels=params.z_channels,
297
+ )
298
+ self.reg = DiagonalGaussian()
299
+
300
+ self.scale_factor = params.scale_factor
301
+ self.shift_factor = params.shift_factor
302
+
303
+ def encode(self, x: Tensor) -> Tensor:
304
+ z = self.reg(self.encoder(x))
305
+ z = self.scale_factor * (z - self.shift_factor)
306
+ return z
307
+
308
+ def decode(self, z: Tensor) -> Tensor:
309
+ z = z / self.scale_factor + self.shift_factor
310
+ return self.decoder(z)
311
+
312
+ def forward(self, x: Tensor) -> Tensor:
313
+ return self.decode(self.encode(x))
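
`encode()` returns the scaled-and-shifted posterior mean (note that `DiagonalGaussian` above returns `mean` without adding noise), and `decode()` undoes the scaling before running the decoder. A hedged round-trip sketch; the 16 latent channels and 8x downsampling assume the flux-dev autoencoder configuration from `flux/util.py`, which is not included in this commit:

```python
# Hedged round-trip sketch; load_ae is used the same way in app.py above.
import torch
from flux.util import load_ae

ae = load_ae("flux-dev", "cpu")
img = torch.rand(1, 3, 256, 256) * 2 - 1  # NCHW image scaled to [-1, 1]
z = ae.encode(img)                        # latents, roughly (1, 16, 32, 32) under the assumed config
recon = ae.decode(z)                      # back to (1, 3, 256, 256); not exact, the AE is lossy
```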
flux/modules/conditioner.py ADDED
@@ -0,0 +1,38 @@
+ from torch import Tensor, nn
+ from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
+                           T5Tokenizer)
+
+
+ class HFEmbedder(nn.Module):
+     def __init__(self, version: str, max_length: int, is_clip, **hf_kwargs):
+         super().__init__()
+         self.is_clip = is_clip
+         self.max_length = max_length
+         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
+
+         if self.is_clip:
+             self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
+             self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
+         else:
+             self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
+             self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
+
+         self.hf_module = self.hf_module.eval().requires_grad_(False)
+
+     def forward(self, text: list[str]) -> Tensor:
+         batch_encoding = self.tokenizer(
+             text,
+             truncation=True,
+             max_length=self.max_length,
+             return_length=False,
+             return_overflowing_tokens=False,
+             padding="max_length",
+             return_tensors="pt",
+         )
+
+         outputs = self.hf_module(
+             input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
+             attention_mask=None,
+             output_hidden_states=False,
+         )
+         return outputs[self.output_key]
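
`HFEmbedder` wraps either a CLIP text encoder (returning the pooled `pooler_output`) or a T5 encoder (returning `last_hidden_state`); `load_clip`/`load_t5` in `flux/util.py` build it, which is how app.py obtains `clip` and `t5`. A hedged construction sketch; the checkpoint names below are the usual public ones and are an assumption, since `flux/util.py` is not part of this diff:

```python
# Hedged sketch of constructing HFEmbedder directly; the checkpoint names are assumptions.
import torch
from flux.modules.conditioner import HFEmbedder

clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, is_clip=True, torch_dtype=torch.bfloat16)
t5 = HFEmbedder("google/t5-v1_1-xxl", max_length=512, is_clip=False, torch_dtype=torch.bfloat16)

vec = clip(["a red bicycle leaning on a wall"])  # pooled CLIP embedding, shape (1, 768)
txt = t5(["a red bicycle leaning on a wall"])    # per-token T5 features, shape (1, 512, 4096)
```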
flux/modules/layers.py ADDED
@@ -0,0 +1,386 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from torch import Tensor, nn
7
+
8
+ from flux.math import attention, rope,apply_rope
9
+
10
+ import os
11
+
12
+ class EmbedND(nn.Module):
13
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
14
+ super().__init__()
15
+ self.dim = dim
16
+ self.theta = theta
17
+ self.axes_dim = axes_dim
18
+
19
+ def forward(self, ids: Tensor) -> Tensor:
20
+ n_axes = ids.shape[-1]
21
+ emb = torch.cat(
22
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], # [1, 1, 4592, 8, 2, 2]) [1, 1, 4592, 28, 2, 2]) [1, 1, 4592, 28, 2, 2])
23
+ dim=-3,
24
+ )
25
+
26
+ return emb.unsqueeze(1)
27
+
28
+
29
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
30
+ """
31
+ Create sinusoidal timestep embeddings.
32
+ :param t: a 1-D Tensor of N indices, one per batch element.
33
+ These may be fractional.
34
+ :param dim: the dimension of the output.
35
+ :param max_period: controls the minimum frequency of the embeddings.
36
+ :return: an (N, D) Tensor of positional embeddings.
37
+ """
38
+ t = time_factor * t
39
+ half = dim // 2
40
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
41
+ t.device
42
+ )
43
+
44
+ args = t[:, None].float() * freqs[None]
45
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
46
+ if dim % 2:
47
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
48
+ if torch.is_floating_point(t):
49
+ embedding = embedding.to(t)
50
+ return embedding
51
+
52
+
53
+ class MLPEmbedder(nn.Module):
54
+ def __init__(self, in_dim: int, hidden_dim: int):
55
+ super().__init__()
56
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
57
+ self.silu = nn.SiLU()
58
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
59
+
60
+ def forward(self, x: Tensor) -> Tensor:
61
+ return self.out_layer(self.silu(self.in_layer(x)))
62
+
63
+
64
+ class RMSNorm(torch.nn.Module):
65
+ def __init__(self, dim: int):
66
+ super().__init__()
67
+ self.scale = nn.Parameter(torch.ones(dim))
68
+
69
+ def forward(self, x: Tensor):
70
+ x_dtype = x.dtype
71
+ x = x.float()
72
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
73
+ return (x * rrms).to(dtype=x_dtype) * self.scale
74
+
75
+
76
+ class QKNorm(torch.nn.Module):
77
+ def __init__(self, dim: int):
78
+ super().__init__()
79
+ self.query_norm = RMSNorm(dim)
80
+ self.key_norm = RMSNorm(dim)
81
+
82
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
83
+ q = self.query_norm(q)
84
+ k = self.key_norm(k)
85
+ return q.to(v), k.to(v)
86
+
87
+
88
+ class SelfAttention(nn.Module):
89
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
90
+ super().__init__()
91
+ self.num_heads = num_heads
92
+ head_dim = dim // num_heads
93
+
94
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
95
+ self.norm = QKNorm(head_dim)
96
+ self.proj = nn.Linear(dim, dim)
97
+
98
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
99
+ qkv = self.qkv(x)
100
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
101
+ q, k = self.norm(q, k, v)
102
+ x = attention(q, k, v, pe=pe)
103
+ x = self.proj(x)
104
+ return x
105
+
106
+
107
+ @dataclass
108
+ class ModulationOut:
109
+ shift: Tensor
110
+ scale: Tensor
111
+ gate: Tensor
112
+
113
+
114
+ class Modulation(nn.Module):
115
+ def __init__(self, dim: int, double: bool):
116
+ super().__init__()
117
+ self.is_double = double
118
+ self.multiplier = 6 if double else 3
119
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
120
+
121
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
122
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
123
+
124
+ return (
125
+ ModulationOut(*out[:3]),
126
+ ModulationOut(*out[3:]) if self.is_double else None,
127
+ )
128
+
129
+
130
+ class DoubleStreamBlock(nn.Module):
131
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
132
+ super().__init__()
133
+
134
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
135
+ self.num_heads = num_heads
136
+ self.hidden_size = hidden_size
137
+ self.img_mod = Modulation(hidden_size, double=True)
138
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
139
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
140
+
141
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
142
+ self.img_mlp = nn.Sequential(
143
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
144
+ nn.GELU(approximate="tanh"),
145
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
146
+ )
147
+
148
+ self.txt_mod = Modulation(hidden_size, double=True)
149
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
150
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
151
+
152
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
153
+ self.txt_mlp = nn.Sequential(
154
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
155
+ nn.GELU(approximate="tanh"),
156
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
157
+ )
158
+
159
+ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
160
+ img_mod1, img_mod2 = self.img_mod(vec)
161
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
162
+
163
+ # prepare image for attention
164
+ img_modulated = self.img_norm1(img)
165
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
166
+ img_qkv = self.img_attn.qkv(img_modulated)
167
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) # torch.Size([1, 24, 4080, 128])
168
+
169
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
170
+ # prepare txt for attention
171
+ txt_modulated = self.txt_norm1(txt)
172
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
173
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
174
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
175
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
176
+
177
+ # run actual attention
178
+ q = torch.cat((txt_q, img_q), dim=2) # [8, 24, 512, 128] + [8, 24, 900, 128] -> [8, 24, 1412, 128]
179
+ k = torch.cat((txt_k, img_k), dim=2)
180
+ v = torch.cat((txt_v, img_v), dim=2)
181
+ # import pdb;pdb.set_trace()
182
+ attn = attention(q, k, v, pe=pe)
183
+
184
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
185
+
186
+ # calculate the img blocks
187
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
188
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
189
+
190
+ # calculate the txt blocks
191
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
192
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
193
+ return img, txt
194
+ class SingleStreamBlock(nn.Module):
195
+ """
196
+ A DiT block with parallel linear layers as described in
197
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
198
+ """
199
+
200
+ def __init__(
201
+ self,
202
+ hidden_size: int,
203
+ num_heads: int,
204
+ mlp_ratio: float = 4.0,
205
+ qk_scale: float | None = None,
206
+ ):
207
+ super().__init__()
208
+ self.hidden_dim = hidden_size
209
+ self.num_heads = num_heads
210
+ head_dim = hidden_size // num_heads
211
+ self.scale = qk_scale or head_dim**-0.5
212
+
213
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
214
+ # qkv and mlp_in
215
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
216
+ # proj and mlp_out
217
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
218
+
219
+ self.norm = QKNorm(head_dim)
220
+
221
+ self.hidden_size = hidden_size
222
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
223
+
224
+ self.mlp_act = nn.GELU(approximate="tanh")
225
+ self.modulation = Modulation(hidden_size, double=False)
226
+
227
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
228
+ mod, _ = self.modulation(vec)
229
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
230
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
231
+
232
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
233
+ q, k = self.norm(q, k, v)
234
+
235
+ # compute attention
236
+ attn = attention(q, k, v, pe=pe)
237
+ # compute activation in mlp stream, cat again and run second linear layer
238
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
239
+ return x + mod.gate * output
240
+
241
+
242
+ class LastLayer(nn.Module):
243
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
244
+ super().__init__()
245
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
246
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
247
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
248
+
249
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
250
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
251
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
252
+ x = self.linear(x)
253
+ return x
254
+
255
+
256
+ class LastLayer(nn.Module):
257
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
258
+ super().__init__()
259
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
260
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
261
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
262
+
263
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
264
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
265
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
266
+ x = self.linear(x)
267
+ return x
268
+
269
+
270
+ class DoubleStreamBlock_kv(DoubleStreamBlock):
271
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
272
+ super().__init__(hidden_size, num_heads, mlp_ratio, qkv_bias)
273
+
274
+ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, info) -> tuple[Tensor, Tensor]:
275
+ img_mod1, img_mod2 = self.img_mod(vec)
276
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
277
+
278
+ # prepare image for attention
279
+ img_modulated = self.img_norm1(img)
280
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
281
+ img_qkv = self.img_attn.qkv(img_modulated)
282
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) # torch.Size([1, 24, 4080, 128])
283
+
284
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
285
+ # prepare txt for attention
286
+ txt_modulated = self.txt_norm1(txt)
287
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
288
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
289
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
290
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
291
+
292
+ feature_k_name = str(info['t']) + '_' + str(info['id']) + '_' + 'MB' + '_' + 'K'
293
+ feature_v_name = str(info['t']) + '_' + str(info['id']) + '_' + 'MB' + '_' + 'V'
294
+ if info['inverse']:
295
+ info['feature'][feature_k_name] = img_k.cpu()
296
+ info['feature'][feature_v_name] = img_v.cpu()
297
+ q = torch.cat((txt_q, img_q), dim=2) # [B, 24, 512, 128] + [B, 24, 900, 128] -> [B, 24, 1412, 128]
298
+ k = torch.cat((txt_k, img_k), dim=2)
299
+ v = torch.cat((txt_v, img_v), dim=2)
300
+ if 'attention_mask' in info:
301
+ attn = attention(q, k, v, pe=pe,attention_mask=info['attention_mask'])
302
+ else:
303
+ attn = attention(q, k, v, pe=pe)
304
+
305
+ # elif feature_k_name in info['feature']:
306
+ else:
307
+ source_img_k = info['feature'][feature_k_name].to(img.device)
308
+ source_img_v = info['feature'][feature_v_name].to(img.device)
309
+
310
+ mask_indices = info['mask_indices'] # masked token indices in image-sequence coordinates
311
+ source_img_k[:, :, mask_indices, ...] = img_k
312
+ source_img_v[:, :, mask_indices, ...] = img_v
313
+
314
+ q = torch.cat((txt_q, img_q), dim=2)
315
+ k = torch.cat((txt_k, source_img_k), dim=2)
316
+ v = torch.cat((txt_v, source_img_v), dim=2)
317
+ attn = attention(q, k, v, pe=pe, pe_q = info['pe_mask'])
318
+
319
+
320
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
321
+
322
+ # calculate the img blocks
323
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
324
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
325
+
326
+ # calculate the txt blocks
327
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
328
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
329
+ return img, txt
330
+
331
+ class SingleStreamBlock_kv(SingleStreamBlock):
332
+ """
333
+ A DiT block with parallel linear layers as described in
334
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
335
+ """
336
+
337
+ def __init__(
338
+ self,
339
+ hidden_size: int,
340
+ num_heads: int,
341
+ mlp_ratio: float = 4.0,
342
+ qk_scale: float | None = None,
343
+ ):
344
+ super().__init__(hidden_size, num_heads, mlp_ratio, qk_scale)
345
+
346
+ def forward(self,x: Tensor, vec: Tensor, pe: Tensor, info) -> Tensor:
347
+ mod, _ = self.modulation(vec)
348
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
349
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
350
+
351
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
352
+ q, k = self.norm(q, k, v)
353
+ img_k = k[:, :, 512:, ...]
354
+ img_v = v[:, :, 512:, ...]
355
+
356
+ txt_k = k[:, :, :512, ...]
357
+ txt_v = v[:, :, :512, ...]
358
+
359
+
360
+ feature_k_name = str(info['t']) + '_' + str(info['id']) + '_' + 'SB' + '_' + 'K'
361
+ feature_v_name = str(info['t']) + '_' + str(info['id']) + '_' + 'SB' + '_' + 'V'
362
+ if info['inverse']:
363
+ info['feature'][feature_k_name] = img_k.cpu()
364
+ info['feature'][feature_v_name] = img_v.cpu()
365
+ if 'attention_mask' in info:
366
+ attn = attention(q, k, v, pe=pe,attention_mask=info['attention_mask'])
367
+ else:
368
+ attn = attention(q, k, v, pe=pe)
369
+
370
+ else:
371
+ source_img_k = info['feature'][feature_k_name].to(x.device)
372
+ source_img_v = info['feature'][feature_v_name].to(x.device)
373
+
374
+ mask_indices = info['mask_indices'] # masked token indices in image-sequence coordinates
375
+ source_img_k[:, :, mask_indices, ...] = img_k
376
+ source_img_v[:, :, mask_indices, ...] = img_v
377
+
378
+ k = torch.cat((txt_k, source_img_k), dim=2)
379
+ v = torch.cat((txt_v, source_img_v), dim=2)
380
+ attn = attention(q, k, v, pe=pe, pe_q = info['pe_mask'])
381
+
382
+ # compute attention
383
+ # attn = attention(q, k, v, pe=pe)
384
+ # compute activation in mlp stream, cat again and run second linear layer
385
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
386
+ return x + mod.gate * output
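
The `*_kv` blocks above implement the KV-Edit cache: during inversion each block stores its image keys and values on the CPU under a name built from the timestep, block index, block type (`MB` for double-stream, `SB` for single-stream) and `K`/`V`; during denoising it reloads them and overwrites only the rows at `mask_indices`, so the background keys and values stay exactly as they were at inversion time. A hedged illustration of that scheme (tensor sizes are dummies):

```python
# Hedged illustration of the KV cache used by the *_kv blocks above; tensors are dummies.
import torch

def cache_key(t, block_id: int, block_type: str, which: str) -> str:
    # block_type: 'MB' for double-stream blocks, 'SB' for single-stream blocks; which: 'K' or 'V'
    return f"{t}_{block_id}_{block_type}_{which}"

feature: dict[str, torch.Tensor] = {}
img_k = torch.randn(1, 24, 4080, 128)                 # (B, heads, image tokens, head_dim)
feature[cache_key(0.75, 3, 'MB', 'K')] = img_k.cpu()  # stored during the inversion pass

# Denoising pass: reload the cached keys and overwrite only the masked rows.
mask_indices = torch.tensor([0, 1, 2])
source_img_k = feature[cache_key(0.75, 3, 'MB', 'K')].clone()
source_img_k[:, :, mask_indices, ...] = img_k[:, :, mask_indices, ...]
```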
flux/sampling.py ADDED
@@ -0,0 +1,306 @@
1
+ import math
2
+ from typing import Callable
3
+
4
+ import torch
5
+ from einops import rearrange, repeat
6
+ from torch import Tensor
7
+
8
+ from .model import Flux,Flux_kv
9
+ from .modules.conditioner import HFEmbedder
10
+ from tqdm import tqdm
11
+ from tqdm.contrib import tzip
12
+
13
+ def get_noise(
14
+ num_samples: int,
15
+ height: int,
16
+ width: int,
17
+ device: torch.device,
18
+ dtype: torch.dtype,
19
+ seed: int,
20
+ ):
21
+ return torch.randn(
22
+ num_samples,
23
+ 16,
24
+ # allow for packing
25
+ 2 * math.ceil(height / 16),
26
+ 2 * math.ceil(width / 16),
27
+ device=device,
28
+ dtype=dtype,
29
+ generator=torch.Generator(device=device).manual_seed(seed),
30
+ )
31
+
32
+
33
+ def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
34
+ bs, c, h, w = img.shape
35
+ if bs == 1 and not isinstance(prompt, str):
36
+ bs = len(prompt)
37
+
38
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
39
+ if img.shape[0] == 1 and bs > 1:
40
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
41
+
42
+ img_ids = torch.zeros(h // 2, w // 2, 3)
43
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
44
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
45
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
46
+
47
+ if isinstance(prompt, str):
48
+ prompt = [prompt]
49
+ txt = t5(prompt)
50
+ if txt.shape[0] == 1 and bs > 1:
51
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
52
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
53
+
54
+ vec = clip(prompt)
55
+ if vec.shape[0] == 1 and bs > 1:
56
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
57
+
58
+ return {
59
+ "img": img,
60
+ "img_ids": img_ids.to(img.device),
61
+ "txt": txt.to(img.device),
62
+ "txt_ids": txt_ids.to(img.device),
63
+ "vec": vec.to(img.device),
64
+ }
65
+
66
+ def prepare_flowedit(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, source_prompt: str | list[str],target_prompt) -> dict[str, Tensor]:
67
+ bs, c, h, w = img.shape
68
+ if bs == 1 and not isinstance(source_prompt, str):
69
+ bs = len(source_prompt)
70
+
71
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
72
+ if img.shape[0] == 1 and bs > 1:
73
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
74
+
75
+ img_ids = torch.zeros(h // 2, w // 2, 3)
76
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
77
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
78
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
79
+
80
+ # if isinstance(prompt, str):
81
+ # prompt = [prompt]
82
+ # txt = t5(prompt)
83
+ # if txt.shape[0] == 1 and bs > 1:
84
+ # txt = repeat(txt, "1 ... -> bs ...", bs=bs)
85
+ # txt_ids = torch.zeros(bs, txt.shape[1], 3)
86
+
87
+ # vec = clip(prompt)
88
+ # if vec.shape[0] == 1 and bs > 1:
89
+ # vec = repeat(vec, "1 ... -> bs ...", bs=bs)
90
+ if isinstance(source_prompt, str):
91
+ source_prompt = [source_prompt]
92
+ source_txt = t5(source_prompt)
93
+ if source_txt.shape[0] == 1 and bs > 1:
94
+ source_txt = repeat(source_txt, "1 ... -> bs ...", bs=bs)
95
+ source_txt_ids = torch.zeros(bs, source_txt.shape[1], 3)
96
+
97
+ source_vec = clip(source_prompt)
98
+ if source_vec.shape[0] == 1 and bs > 1:
99
+ source_vec = repeat(source_vec, "1 ... -> bs ...", bs=bs)
100
+
101
+ if isinstance(target_prompt, str):
102
+ target_prompt = [target_prompt]
103
+ target_txt = t5(target_prompt)
104
+ if target_txt.shape[0] == 1 and bs > 1:
105
+ target_txt = repeat(target_txt, "1 ... -> bs ...", bs=bs)
106
+ target_txt_ids = torch.zeros(bs, target_txt.shape[1], 3)
107
+
108
+ target_vec = clip(target_prompt)
109
+ if target_vec.shape[0] == 1 and bs > 1:
110
+ target_vec = repeat(target_vec, "1 ... -> bs ...", bs=bs)
111
+
112
+
113
+ return {
114
+ "img": img,
115
+ "img_ids": img_ids.to(img.device),
116
+ "source_txt": source_txt.to(img.device),
117
+ "source_txt_ids": source_txt_ids.to(img.device),
118
+ "source_vec": source_vec.to(img.device),
119
+ "target_txt": target_txt.to(img.device),
120
+ "target_txt_ids": target_txt_ids.to(img.device),
121
+ "target_vec": target_vec.to(img.device)
122
+ }
123
+
124
+ def time_shift(mu: float, sigma: float, t: Tensor):
125
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
126
+
127
+
128
+ def get_lin_function(
129
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
130
+ ) -> Callable[[float], float]:
131
+ m = (y2 - y1) / (x2 - x1)
132
+ b = y1 - m * x1
133
+ return lambda x: m * x + b
134
+
135
+
136
+ def get_schedule(
137
+ num_steps: int,
138
+ image_seq_len: int,
139
+ base_shift: float = 0.5,
140
+ max_shift: float = 1.15,
141
+ shift: bool = True,
142
+ ) -> list[float]:
143
+ # extra step for zero
144
+ timesteps = torch.linspace(1, 0, num_steps + 1)
145
+
146
+ # shifting the schedule to favor high timesteps for higher signal images
147
+ if shift:
148
+ # estimate mu based on linear estimation between two points
149
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
150
+ timesteps = time_shift(mu, 1.0, timesteps)
151
+
152
+ return timesteps.tolist()
153
+
154
+
155
+ def denoise(
156
+ model: Flux,
157
+ # model input
158
+ img: Tensor,
159
+ img_ids: Tensor,
160
+ txt: Tensor,
161
+ txt_ids: Tensor,
162
+ vec: Tensor,
163
+ # sampling parameters
164
+ timesteps: list[float],
165
+ guidance: float = 4.0,
166
+ ):
167
+ # this is ignored for schnell
168
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
169
+ for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
170
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
171
+ pred = model(
172
+ img=img,
173
+ img_ids=img_ids,
174
+ txt=txt,
175
+ txt_ids=txt_ids,
176
+ y=vec,
177
+ timesteps=t_vec,
178
+ guidance=guidance_vec,
179
+ )
180
+
181
+ img = img + (t_prev - t_curr) * pred
182
+
183
+ return img
184
+
185
+ def unpack(x: Tensor, height: int, width: int) -> Tensor:
186
+ return rearrange(
187
+ x,
188
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
189
+ h=math.ceil(height / 16),
190
+ w=math.ceil(width / 16),
191
+ ph=2,
192
+ pw=2,
193
+ )
194
+
195
+ def denoise_kv(
196
+ model: Flux_kv,
197
+ # model input
198
+ img: Tensor,
199
+ img_ids: Tensor,
200
+ txt: Tensor,
201
+ txt_ids: Tensor,
202
+ vec: Tensor,
203
+ # sampling parameters
204
+ timesteps: list[float],
205
+ inverse,
206
+ info,
207
+ guidance: float = 4.0
208
+ ):
209
+
210
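+ # inversion walks the same schedule in the opposite direction (from low to high noise)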
+ if inverse:
211
+ timesteps = timesteps[::-1]
212
+
213
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
214
+
215
+ for i, (t_curr, t_prev) in enumerate(tzip(timesteps[:-1], timesteps[1:])):
216
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
217
+ info['t'] = t_prev if inverse else t_curr
218
+
219
+ if inverse:
220
+ img_name = str(info['t']) + '_' + 'img'
221
+ info['feature'][img_name] = img.cpu()
222
+ else:
223
+ img_name = str(info['t']) + '_' + 'img'
224
+ source_img = info['feature'][img_name].to(img.device)
225
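+ # within the selected tokens, keep the cached source latent where the mask is 0 and the freshly denoised latent where it is 1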
+ img = source_img[:, info['mask_indices'],...] * (1 - info['mask'][:, info['mask_indices'],...]) + img * info['mask'][:, info['mask_indices'],...]
226
+ pred = model(
227
+ img=img,
228
+ img_ids=img_ids,
229
+ txt=txt,
230
+ txt_ids=txt_ids,
231
+ y=vec,
232
+ timesteps=t_vec,
233
+ guidance=guidance_vec,
234
+ info=info
235
+ )
236
+ img = img + (t_prev - t_curr) * pred
237
+ return img, info
238
+
239
+ def denoise_kv_inf(
240
+ model: Flux_kv,
241
+ # model input
242
+ img: Tensor,
243
+ img_ids: Tensor,
244
+ source_txt: Tensor,
245
+ source_txt_ids: Tensor,
246
+ source_vec: Tensor,
247
+ target_txt: Tensor,
248
+ target_txt_ids: Tensor,
249
+ target_vec: Tensor,
250
+ # sampling parameters
251
+ timesteps: list[float],
252
+ target_guidance: float = 4.0,
253
+ source_guidance: float = 4.0,
254
+ info: dict = {},
255
+ ):
256
+
257
+ target_guidance_vec = torch.full((img.shape[0],), target_guidance, device=img.device, dtype=img.dtype)
258
+ source_guidance_vec = torch.full((img.shape[0],), source_guidance, device=img.device, dtype=img.dtype)
259
+
260
+ mask_indices = info['mask_indices']
261
+ init_img = img.clone() # torch.Size([1, 4080, 64])
262
+ z_fe = img[:, mask_indices,...]
263
+
264
+ noise_list = []
265
+ for i in range(len(timesteps)):
266
+ noise = torch.randn(init_img.size(), dtype=init_img.dtype,
267
+ layout=init_img.layout, device=init_img.device,
268
+ generator=torch.Generator(device=init_img.device).manual_seed(0)) # draw fresh noise for every step; the source latent is re-noised according to t
269
+ noise_list.append(noise)
270
+
271
+ for i, (t_curr, t_prev) in enumerate(tzip(timesteps[:-1], timesteps[1:])): # timesteps run from high to low noise
272
+
273
+ info['t'] = 'inf'
274
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
275
+
276
+ z_src = (1 - t_curr) * init_img + t_curr * noise_list[i]
277
+ z_tar = z_src[:, mask_indices,...] - init_img[:, mask_indices,...] + z_fe
278
+
279
+ info['inverse'] = True
280
+ info['feature'] = {} # clear the cached KV features
281
+ v_src = model(
282
+ img=z_src,
283
+ img_ids=img_ids,
284
+ txt=source_txt,
285
+ txt_ids=source_txt_ids,
286
+ y=source_vec,
287
+ timesteps=t_vec,
288
+ guidance=source_guidance_vec,
289
+ info=info
290
+ )
291
+
292
+ info['inverse'] = False
293
+ v_tar = model(
294
+ img=z_tar,
295
+ img_ids=img_ids,
296
+ txt=target_txt,
297
+ txt_ids=target_txt_ids,
298
+ y=target_vec,
299
+ timesteps=t_vec,
300
+ guidance=target_guidance_vec,
301
+ info=info
302
+ )
303
+
304
+ v_fe = v_tar - v_src[:, mask_indices,...]
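+ # the target/source velocity difference drives the update of the masked-region latent only; background tokens are never integrated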
305
+ z_fe = z_fe + (t_prev - t_curr) * v_fe * info['mask'][:, mask_indices,...]
306
+ return z_fe, info
flux/util.py ADDED
@@ -0,0 +1,201 @@
1
+ import os
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from huggingface_hub import hf_hub_download
7
+ from imwatermark import WatermarkEncoder
8
+ from safetensors.torch import load_file as load_sft
9
+
10
+ from flux.model import Flux, FluxParams
11
+ from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
12
+ from flux.modules.conditioner import HFEmbedder
13
+
14
+
15
+ @dataclass
16
+ class ModelSpec:
17
+ params: FluxParams
18
+ ae_params: AutoEncoderParams
19
+ ckpt_path: str | None
20
+ ae_path: str | None
21
+ repo_id: str | None
22
+ repo_flow: str | None
23
+ repo_ae: str | None
24
+
25
+ configs = {
26
+ "flux-dev": ModelSpec(
27
+ repo_id="black-forest-labs/FLUX.1-dev",
28
+ repo_flow="flux1-dev.safetensors",
29
+ repo_ae="ae.safetensors",
30
+ ckpt_path=os.getenv("FLUX_DEV"),
31
+ params=FluxParams(
32
+ in_channels=64,
33
+ vec_in_dim=768,
34
+ context_in_dim=4096,
35
+ hidden_size=3072,
36
+ mlp_ratio=4.0,
37
+ num_heads=24,
38
+ depth=19,
39
+ depth_single_blocks=38,
40
+ axes_dim=[16, 56, 56],
41
+ theta=10_000,
42
+ qkv_bias=True,
43
+ guidance_embed=True,
44
+ ),
45
+ ae_path=os.getenv("AE"),
46
+ ae_params=AutoEncoderParams(
47
+ resolution=256,
48
+ in_channels=3,
49
+ ch=128,
50
+ out_ch=3,
51
+ ch_mult=[1, 2, 4, 4],
52
+ num_res_blocks=2,
53
+ z_channels=16,
54
+ scale_factor=0.3611,
55
+ shift_factor=0.1159,
56
+ ),
57
+ ),
58
+ "flux-schnell": ModelSpec(
59
+ repo_id="black-forest-labs/FLUX.1-schnell",
60
+ repo_flow="flux1-schnell.safetensors",
61
+ repo_ae="ae.safetensors",
62
+ ckpt_path=os.getenv("FLUX_SCHNELL"),
63
+ params=FluxParams(
64
+ in_channels=64,
65
+ vec_in_dim=768,
66
+ context_in_dim=4096,
67
+ hidden_size=3072,
68
+ mlp_ratio=4.0,
69
+ num_heads=24,
70
+ depth=19,
71
+ depth_single_blocks=38,
72
+ axes_dim=[16, 56, 56],
73
+ theta=10_000,
74
+ qkv_bias=True,
75
+ guidance_embed=False,
76
+ ),
77
+ ae_path=os.getenv("AE"),
78
+ ae_params=AutoEncoderParams(
79
+ resolution=256,
80
+ in_channels=3,
81
+ ch=128,
82
+ out_ch=3,
83
+ ch_mult=[1, 2, 4, 4],
84
+ num_res_blocks=2,
85
+ z_channels=16,
86
+ scale_factor=0.3611,
87
+ shift_factor=0.1159,
88
+ ),
89
+ ),
90
+ }
91
+
92
+
93
+ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
94
+ if len(missing) > 0 and len(unexpected) > 0:
95
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
96
+ print("\n" + "-" * 79 + "\n")
97
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
98
+ elif len(missing) > 0:
99
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
100
+ elif len(unexpected) > 0:
101
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
102
+
103
+
104
+ def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True, flux_cls=Flux) -> Flux:
105
+ # Loading Flux
106
+ print("Init model")
107
+
108
+ ckpt_path = configs[name].ckpt_path
109
+ if (
110
+ ckpt_path is None
111
+ and configs[name].repo_id is not None
112
+ and configs[name].repo_flow is not None
113
+ and hf_download
114
+ ):
115
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
116
+
117
+ with torch.device("meta" if ckpt_path is not None else device):
118
+ model = flux_cls(configs[name].params).to(torch.bfloat16)
119
+
120
+ if ckpt_path is not None:
121
+ print("Loading checkpoint")
122
+ # load_sft doesn't support torch.device
123
+ sd = load_sft(ckpt_path, device=str(device))
124
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
125
+ print_load_warning(missing, unexpected)
126
+ return model
127
+
128
+
129
+ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
130
+ # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
131
+ return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, is_clip=False, torch_dtype=torch.bfloat16).to(device)
132
+
133
+
134
+ def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
135
+ return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, is_clip=True, torch_dtype=torch.bfloat16).to(device)
136
+
137
+
138
+ def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
139
+ ckpt_path = configs[name].ae_path
140
+ if (
141
+ ckpt_path is None
142
+ and configs[name].repo_id is not None
143
+ and configs[name].repo_ae is not None
144
+ and hf_download
145
+ ):
146
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
147
+
148
+ # Loading the autoencoder
149
+ print("Init AE")
150
+ with torch.device("meta" if ckpt_path is not None else device):
151
+ ae = AutoEncoder(configs[name].ae_params)
152
+
153
+ if ckpt_path is not None:
154
+ sd = load_sft(ckpt_path, device=str(device))
155
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
156
+ print_load_warning(missing, unexpected)
157
+ return ae
158
+
159
+
160
+ class WatermarkEmbedder:
161
+ def __init__(self, watermark):
162
+ self.watermark = watermark
163
+ self.num_bits = len(WATERMARK_BITS)
164
+ self.encoder = WatermarkEncoder()
165
+ self.encoder.set_watermark("bits", self.watermark)
166
+
167
+ def __call__(self, image: torch.Tensor) -> torch.Tensor:
168
+ """
169
+ Adds a predefined watermark to the input image
170
+
171
+ Args:
172
+ image: ([N,] B, RGB, H, W) in range [-1, 1]
173
+
174
+ Returns:
175
+ same as input but watermarked
176
+ """
177
+ image = 0.5 * image + 0.5
178
+ squeeze = len(image.shape) == 4
179
+ if squeeze:
180
+ image = image[None, ...]
181
+ n = image.shape[0]
182
+ image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
183
+ # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
184
+ # watermarking library expects input as cv2 BGR format
185
+ for k in range(image_np.shape[0]):
186
+ image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
187
+ image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
188
+ image.device
189
+ )
190
+ image = torch.clamp(image / 255, min=0.0, max=1.0)
191
+ if squeeze:
192
+ image = image[0]
193
+ image = 2 * image - 1
194
+ return image
195
+
196
+
197
+ # A fixed 48-bit message that was chosen at random
198
+ WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
199
+ # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
200
+ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
201
+ embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
models/__pycache__/kv_edit.cpython-310.pyc ADDED
Binary file (6.89 kB).
 
models/kv_edit.py ADDED
@@ -0,0 +1,213 @@
1
+ from dataclasses import dataclass
2
+
3
+ from einops import rearrange,repeat
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import Tensor
7
+ from typing import List
8
+
9
+ from flux.sampling import get_schedule, unpack,denoise_kv,denoise_kv_inf
10
+ from flux.util import load_flow_model
11
+ from flux.model import Flux_kv
12
+
13
+ @dataclass
14
+ class SamplingOptions:
15
+ source_prompt: str = ''
16
+ target_prompt: str = ''
17
+ # prompt: str
18
+ width: int = 1366
19
+ height: int = 768
20
+ inversion_num_steps: int = 0
21
+ denoise_num_steps: int = 0
22
+ skip_step: int = 0
23
+ inversion_guidance: float = 1.0
24
+ denoise_guidance: float = 1.0
25
+ seed: int = 42
26
+ re_init: bool = False
27
+ attn_mask: bool = False
28
+
29
+ class only_Flux(torch.nn.Module): # only holds the shared initialization
30
+ def __init__(self, device,name='flux-dev'):
31
+ self.device = device
32
+ self.name = name
33
+ super().__init__()
34
+ self.model = load_flow_model(self.name, device=self.device,flux_cls=Flux_kv)
35
+
36
+ def create_attention_mask(self,seq_len, mask_indices, text_len=512, device='cuda'):
37
+ """
38
+ Create a custom attention mask.
39
+
40
+ Args:
41
+ seq_len (int): sequence length.
42
+ mask_indices (List[int]): indices of the masked region within the image tokens.
43
+ text_len (int): number of text tokens, 512 by default.
44
+ device (str): device type, e.g. 'cuda' or 'cpu'.
45
+
46
+ Returns:
47
+ torch.Tensor: attention mask of shape (seq_len, seq_len).
48
+ """
49
+ # initialize the mask to all False
50
+ attention_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool, device=device)
51
+
52
+ # text token indices
53
+ text_indices = torch.arange(0, text_len, device=device)
54
+
55
+ # masked-region token indices
56
+ mask_token_indices = torch.tensor([idx + text_len for idx in mask_indices], device=device)
57
+
58
+ # background-region token indices
59
+ all_indices = torch.arange(text_len, seq_len, device=device)
60
+ background_token_indices = torch.tensor([idx for idx in all_indices if idx not in mask_token_indices])
61
+
62
+ # text queries may attend to all keys
63
+ attention_mask[text_indices.unsqueeze(1).expand(-1, seq_len)] = True
64
+ attention_mask[text_indices.unsqueeze(1), text_indices] = True # attend to text
65
+ attention_mask[text_indices.unsqueeze(1), background_token_indices] = True # attend to background
66
+
67
+
68
+ # attention_mask[mask_token_indices.unsqueeze(1), background_token_indices] = True # attend to background
69
+ attention_mask[mask_token_indices.unsqueeze(1), text_indices] = True # attend to text
70
+ attention_mask[mask_token_indices.unsqueeze(1), mask_token_indices] = True # attend to the masked region
71
+
72
+
73
+ # attention_mask[background_token_indices.unsqueeze(1).expand(-1, seq_len), :] = False
74
+ attention_mask[background_token_indices.unsqueeze(1), mask_token_indices] = True # attend to the masked region
75
+ attention_mask[background_token_indices.unsqueeze(1), text_indices] = True # attend to text
76
+ attention_mask[background_token_indices.unsqueeze(1), background_token_indices] = True # attend to the background region
77
+
78
+ return attention_mask.unsqueeze(0)
79
+
80
+ class Flux_kv_edit_inf(only_Flux):
81
+ def __init__(self, device,name):
82
+ super().__init__(device,name)
83
+
84
+ @torch.inference_mode()
85
+ def forward(self,inp,inp_target,mask:Tensor,opts):
86
+ ############# derive token-sequence indices from the mask #############
87
+ info = {}
88
+ info['feature'] = {}
89
+ bs, L, d = inp["img"].shape
90
+ h = opts.height // 8
91
+ w = opts.width // 8
92
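+ # downsample the pixel-space mask to the latent grid, binarize it, and pack it to the token layout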
+ mask = F.interpolate(mask, size=(h,w), mode='bilinear', align_corners=False)
93
+ mask[mask > 0] = 1
94
+
95
+ mask = repeat(mask, 'b c h w -> b (repeat c) h w', repeat=16)
96
+ # mask = F.max_pool2d(mask, kernel_size=3, stride=1, padding=1)
97
+ # mask = mask.flatten().to(self.device[1])
98
+ mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
99
+ info['mask'] = mask
100
+ bool_mask = (mask.sum(dim=2) > 0.5)
101
+ info['mask_indices'] = torch.nonzero(bool_mask)[:,1] # fancy indexing (integer tensor indexing a tensor); indices are in image-token space, add 512 for positions in the full sequence
102
+ # handle inversion separately
103
+ if opts.attn_mask and (~bool_mask).any(): # apply the attention mask only if some token is unmasked; an all-True mask needs none
104
+ attention_mask = self.create_attention_mask(L+512, info['mask_indices'], device=self.device)
105
+ else:
106
+ attention_mask = None
107
+ info['attention_mask'] = attention_mask
108
+
109
+ denoise_timesteps = get_schedule(opts.denoise_num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
110
+ # denoise_timesteps = get_schedule(opts.denoise_num_steps, inp_target["img"].shape[1], shift=False)
111
+ denoise_timesteps = denoise_timesteps[opts.skip_step:]
112
+
113
+ z0 = inp["img"]
114
+
115
+ with torch.no_grad():
116
+ info['inject'] = True
117
+ z_fe, info = denoise_kv_inf(self.model, img=inp["img"], img_ids=inp['img_ids'],
118
+ source_txt=inp['txt'], source_txt_ids=inp['txt_ids'], source_vec=inp['vec'],
119
+ target_txt=inp_target['txt'], target_txt_ids=inp_target['txt_ids'], target_vec=inp_target['vec'],
120
+ timesteps=denoise_timesteps, source_guidance=opts.inversion_guidance, target_guidance=opts.denoise_guidance,
121
+ info=info)
122
+ mask_indices = info['mask_indices'] # indices in image-sequence (token) coordinates
123
+ # the edited tokens were gathered by index; scatter them back into place
124
+ z0[:, mask_indices,...] = z_fe
125
+
126
+ # decode latents to pixel space
127
+ z0 = unpack(z0.float(), opts.height, opts.width)
128
+ del info
129
+ return z0
130
+
131
+ class Flux_kv_edit(only_Flux):
132
+ def __init__(self, device,name):
133
+ super().__init__(device,name)
134
+
135
+ @torch.inference_mode()
136
+ def forward(self,inp,inp_target,mask:Tensor,opts):
137
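+ # two stages: invert the source image while caching K/V, then denoise only the masked tokens under the target prompt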
+ z0,zt,info = self.inverse(inp,mask,opts)
138
+ z0 = self.denoise(z0,zt,inp_target,mask,opts,info)
139
+ return z0
140
+ @torch.inference_mode()
141
+ def inverse(self,inp,mask,opts):
142
+ info = {}
143
+ info['feature'] = {}
144
+ bs, L, d = inp["img"].shape
145
+ h = opts.height // 8
146
+ w = opts.width // 8
147
+ # mask = F.interpolate(mask, size=(h,w), mode='nearest')
148
+
149
+ if opts.attn_mask:
150
+ mask = F.interpolate(mask, size=(h,w), mode='bilinear', align_corners=False)
151
+ mask[mask > 0] = 1
152
+
153
+ mask = repeat(mask, 'b c h w -> b (repeat c) h w', repeat=16)
154
+ # mask = F.max_pool2d(mask, kernel_size=3, stride=1, padding=1)
155
+ # mask = mask.flatten().to(self.device[1])
156
+ mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
157
+ bool_mask = (mask.sum(dim=2) > 0.5)
158
+ mask_indices = torch.nonzero(bool_mask)[:,1] # fancy indexing (integer tensor indexing a tensor); indices are in image-token space, add 512 for positions in the full sequence
159
+
160
+ # handle inversion separately
161
+ assert not (~bool_mask).all(), "mask is all false"
162
+ assert not (bool_mask).all(), "mask is all true"
163
+ attention_mask = self.create_attention_mask(L+512, mask_indices, device=mask.device)
164
+ info['attention_mask'] = attention_mask
165
+
166
+
167
+ denoise_timesteps = get_schedule(opts.denoise_num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
168
+ denoise_timesteps = denoise_timesteps[opts.skip_step:]
169
+
170
+ # noising (inversion) pass
171
+ z0 = inp["img"].clone()
172
+ info['inverse'] = True
173
+ zt, info = denoise_kv(self.model, **inp, timesteps=denoise_timesteps, guidance=opts.inversion_guidance, inverse=True, info=info)
174
+ return z0,zt,info
175
+
176
+ @torch.inference_mode()
177
+ def denoise(self,z0,zt,inp_target,mask:Tensor,opts,info):
178
+
179
+ h = opts.height // 8
180
+ w = opts.width // 8
181
+
182
+ mask = F.interpolate(mask, size=(h,w), mode='bilinear', align_corners=False)
183
+ mask[mask > 0] = 1
184
+
185
+ mask = repeat(mask, 'b c h w -> b (repeat c) h w', repeat=16)
186
+
187
+ mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
188
+ info['mask'] = mask
189
+ bool_mask = (mask.sum(dim=2) > 0.5)
190
+ info['mask_indices'] = torch.nonzero(bool_mask)[:,1] # fancy indexing (integer tensor indexing a tensor); indices are in image-token space, add 512 for positions in the full sequence
191
+
192
+ denoise_timesteps = get_schedule(opts.denoise_num_steps, inp_target["img"].shape[1], shift=(self.name != "flux-schnell"))
193
+ denoise_timesteps = denoise_timesteps[opts.skip_step:]
194
+ # reconstruction does not need every token; gather z at the masked indices only
195
+ mask_indices = info['mask_indices'] # indices in image-sequence (token) coordinates
196
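+ # re_init starts the masked region from z0 mixed with fresh noise at the first timestep instead of from the inverted latent zt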
+ if opts.re_init:
197
+ noise = torch.randn_like(zt)
198
+ t = denoise_timesteps[0]
199
+ zt_noise = z0 *(1 - t) + noise * t
200
+ inp_target["img"] = zt_noise[:, mask_indices,...]
201
+ else:
202
+ inp_target["img"] = zt[:, mask_indices,...]
203
+
204
+ info['inverse'] = False
205
+ x, _ = denoise_kv(self.model, **inp_target, timesteps=denoise_timesteps, guidance=opts.denoise_guidance, inverse=False, info=info)
206
+ # the edited tokens were gathered by index; scatter them back into place
207
+ z0[:, mask_indices,...] = z0[:, mask_indices,...] * (1 - info['mask'][:, mask_indices,...]) + x * info['mask'][:, mask_indices,...]
208
+ # x = inp['img'].clone()
209
+
210
+ # decode latents to pixel space
211
+ z0 = unpack(z0.float(), opts.height, opts.width)
212
+ del info
213
+ return z0
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ torch
2
+ einops
3
+ accelerate==0.34.2
4
+ einops==0.8.0
5
+ transformers==4.41.2
6
+ huggingface-hub==0.24.6
7
+ datasets
8
+ omegaconf
9
+ diffusers
10
+ sentencepiece
11
+ opencv-python
12
+ matplotlib
13
+ onnxruntime
14
+ torchvision
15
+ timm
16
+ invisible-watermark
17
+ fire
18
+ tqdm