firzaelbuho committed
Commit 6c4f135
Parent(s): f006919

Upload 364 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. __pycache__/cuda_malloc.cpython-310.pyc +0 -0
  2. __pycache__/execution.cpython-310.pyc +0 -0
  3. __pycache__/folder_paths.cpython-310.pyc +0 -0
  4. __pycache__/latent_preview.cpython-310.pyc +0 -0
  5. __pycache__/nodes.cpython-310.pyc +0 -0
  6. __pycache__/server.cpython-310.pyc +0 -0
  7. app/__pycache__/app_settings.cpython-310.pyc +0 -0
  8. app/__pycache__/user_manager.cpython-310.pyc +0 -0
  9. app/app_settings.py +54 -0
  10. app/user_manager.py +140 -0
  11. comfy/__pycache__/checkpoint_pickle.cpython-310.pyc +0 -0
  12. comfy/__pycache__/cli_args.cpython-310.pyc +0 -0
  13. comfy/__pycache__/clip_model.cpython-310.pyc +0 -0
  14. comfy/__pycache__/clip_vision.cpython-310.pyc +0 -0
  15. comfy/__pycache__/conds.cpython-310.pyc +0 -0
  16. comfy/__pycache__/controlnet.cpython-310.pyc +0 -0
  17. comfy/__pycache__/diffusers_convert.cpython-310.pyc +0 -0
  18. comfy/__pycache__/diffusers_load.cpython-310.pyc +0 -0
  19. comfy/__pycache__/gligen.cpython-310.pyc +0 -0
  20. comfy/__pycache__/latent_formats.cpython-310.pyc +0 -0
  21. comfy/__pycache__/lora.cpython-310.pyc +0 -0
  22. comfy/__pycache__/model_base.cpython-310.pyc +0 -0
  23. comfy/__pycache__/model_detection.cpython-310.pyc +0 -0
  24. comfy/__pycache__/model_management.cpython-310.pyc +0 -0
  25. comfy/__pycache__/model_patcher.cpython-310.pyc +0 -0
  26. comfy/__pycache__/model_sampling.cpython-310.pyc +0 -0
  27. comfy/__pycache__/ops.cpython-310.pyc +0 -0
  28. comfy/__pycache__/options.cpython-310.pyc +0 -0
  29. comfy/__pycache__/sample.cpython-310.pyc +0 -0
  30. comfy/__pycache__/samplers.cpython-310.pyc +0 -0
  31. comfy/__pycache__/sd.cpython-310.pyc +0 -0
  32. comfy/__pycache__/sd1_clip.cpython-310.pyc +0 -0
  33. comfy/__pycache__/sd2_clip.cpython-310.pyc +0 -0
  34. comfy/__pycache__/sdxl_clip.cpython-310.pyc +0 -0
  35. comfy/__pycache__/supported_models.cpython-310.pyc +0 -0
  36. comfy/__pycache__/supported_models_base.cpython-310.pyc +0 -0
  37. comfy/__pycache__/utils.cpython-310.pyc +0 -0
  38. comfy/checkpoint_pickle.py +13 -0
  39. comfy/cldm/__pycache__/cldm.cpython-310.pyc +0 -0
  40. comfy/cldm/cldm.py +312 -0
  41. comfy/cli_args.py +126 -0
  42. comfy/clip_config_bigg.json +23 -0
  43. comfy/clip_model.py +188 -0
  44. comfy/clip_vision.py +116 -0
  45. comfy/clip_vision_config_g.json +18 -0
  46. comfy/clip_vision_config_h.json +18 -0
  47. comfy/clip_vision_config_vitl.json +18 -0
  48. comfy/conds.py +78 -0
  49. comfy/controlnet.py +525 -0
  50. comfy/diffusers_convert.py +261 -0
__pycache__/cuda_malloc.cpython-310.pyc ADDED
Binary file (2.67 kB)

__pycache__/execution.cpython-310.pyc ADDED
Binary file (18.6 kB)

__pycache__/folder_paths.cpython-310.pyc ADDED
Binary file (7.45 kB)

__pycache__/latent_preview.cpython-310.pyc ADDED
Binary file (3.66 kB)

__pycache__/nodes.cpython-310.pyc ADDED
Binary file (59.9 kB)

__pycache__/server.cpython-310.pyc ADDED
Binary file (20.4 kB)

app/__pycache__/app_settings.cpython-310.pyc ADDED
Binary file (2.41 kB)

app/__pycache__/user_manager.cpython-310.pyc ADDED
Binary file (4.12 kB)

app/app_settings.py ADDED
@@ -0,0 +1,54 @@
+ import os
+ import json
+ from aiohttp import web
+
+
+ class AppSettings():
+     def __init__(self, user_manager):
+         self.user_manager = user_manager
+
+     def get_settings(self, request):
+         file = self.user_manager.get_request_user_filepath(
+             request, "comfy.settings.json")
+         if os.path.isfile(file):
+             with open(file) as f:
+                 return json.load(f)
+         else:
+             return {}
+
+     def save_settings(self, request, settings):
+         file = self.user_manager.get_request_user_filepath(
+             request, "comfy.settings.json")
+         with open(file, "w") as f:
+             f.write(json.dumps(settings, indent=4))
+
+     def add_routes(self, routes):
+         @routes.get("/settings")
+         async def get_settings(request):
+             return web.json_response(self.get_settings(request))
+
+         @routes.get("/settings/{id}")
+         async def get_setting(request):
+             value = None
+             settings = self.get_settings(request)
+             setting_id = request.match_info.get("id", None)
+             if setting_id and setting_id in settings:
+                 value = settings[setting_id]
+             return web.json_response(value)
+
+         @routes.post("/settings")
+         async def post_settings(request):
+             settings = self.get_settings(request)
+             new_settings = await request.json()
+             self.save_settings(request, {**settings, **new_settings})
+             return web.Response(status=200)
+
+         @routes.post("/settings/{id}")
+         async def post_setting(request):
+             setting_id = request.match_info.get("id", None)
+             if not setting_id:
+                 return web.Response(status=400)
+             settings = self.get_settings(request)
+             settings[setting_id] = await request.json()
+             self.save_settings(request, settings)
+             return web.Response(status=200)
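For reference, a minimal client-side sketch of the settings endpoints registered above. It assumes a ComfyUI server already running on 127.0.0.1:8188 (the defaults from comfy/cli_args.py later in this commit); the setting id used here is hypothetical.

import json
import urllib.request

BASE = "http://127.0.0.1:8188"  # assumed default host/port from cli_args.py

def get_settings():
    # GET /settings returns the merged comfy.settings.json for the requesting user
    with urllib.request.urlopen(BASE + "/settings") as r:
        return json.load(r)

def set_setting(setting_id, value):
    # POST /settings/{id} stores a single setting value as JSON
    req = urllib.request.Request(
        BASE + "/settings/" + setting_id,
        data=json.dumps(value).encode("utf-8"),
        method="POST",
    )
    urllib.request.urlopen(req)

if __name__ == "__main__":
    set_setting("Comfy.Example", True)  # hypothetical setting id
    print(get_settings())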
app/user_manager.py ADDED
@@ -0,0 +1,140 @@
+ import json
+ import os
+ import re
+ import uuid
+ from aiohttp import web
+ from comfy.cli_args import args
+ from folder_paths import user_directory
+ from .app_settings import AppSettings
+
+ default_user = "default"
+ users_file = os.path.join(user_directory, "users.json")
+
+
+ class UserManager():
+     def __init__(self):
+         global user_directory
+
+         self.settings = AppSettings(self)
+         if not os.path.exists(user_directory):
+             os.mkdir(user_directory)
+         if not args.multi_user:
+             print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
+             print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
+
+         if args.multi_user:
+             if os.path.isfile(users_file):
+                 with open(users_file) as f:
+                     self.users = json.load(f)
+             else:
+                 self.users = {}
+         else:
+             self.users = {"default": "default"}
+
+     def get_request_user_id(self, request):
+         user = "default"
+         if args.multi_user and "comfy-user" in request.headers:
+             user = request.headers["comfy-user"]
+
+         if user not in self.users:
+             raise KeyError("Unknown user: " + user)
+
+         return user
+
+     def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
+         global user_directory
+
+         if type == "userdata":
+             root_dir = user_directory
+         else:
+             raise KeyError("Unknown filepath type:" + type)
+
+         user = self.get_request_user_id(request)
+         path = user_root = os.path.abspath(os.path.join(root_dir, user))
+
+         # prevent leaving /{type}
+         if os.path.commonpath((root_dir, user_root)) != root_dir:
+             return None
+
+         parent = user_root
+
+         if file is not None:
+             # prevent leaving /{type}/{user}
+             path = os.path.abspath(os.path.join(user_root, file))
+             if os.path.commonpath((user_root, path)) != user_root:
+                 return None
+
+         if create_dir and not os.path.exists(parent):
+             os.mkdir(parent)
+
+         return path
+
+     def add_user(self, name):
+         name = name.strip()
+         if not name:
+             raise ValueError("username not provided")
+         user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
+         user_id = user_id + "_" + str(uuid.uuid4())
+
+         self.users[user_id] = name
+
+         global users_file
+         with open(users_file, "w") as f:
+             json.dump(self.users, f)
+
+         return user_id
+
+     def add_routes(self, routes):
+         self.settings.add_routes(routes)
+
+         @routes.get("/users")
+         async def get_users(request):
+             if args.multi_user:
+                 return web.json_response({"storage": "server", "users": self.users})
+             else:
+                 user_dir = self.get_request_user_filepath(request, None, create_dir=False)
+                 return web.json_response({
+                     "storage": "server",
+                     "migrated": os.path.exists(user_dir)
+                 })
+
+         @routes.post("/users")
+         async def post_users(request):
+             body = await request.json()
+             username = body["username"]
+             if username in self.users.values():
+                 return web.json_response({"error": "Duplicate username."}, status=400)
+
+             user_id = self.add_user(username)
+             return web.json_response(user_id)
+
+         @routes.get("/userdata/{file}")
+         async def getuserdata(request):
+             file = request.match_info.get("file", None)
+             if not file:
+                 return web.Response(status=400)
+
+             path = self.get_request_user_filepath(request, file)
+             if not path:
+                 return web.Response(status=403)
+
+             if not os.path.exists(path):
+                 return web.Response(status=404)
+
+             return web.FileResponse(path)
+
+         @routes.post("/userdata/{file}")
+         async def post_userdata(request):
+             file = request.match_info.get("file", None)
+             if not file:
+                 return web.Response(status=400)
+
+             path = self.get_request_user_filepath(request, file)
+             if not path:
+                 return web.Response(status=403)
+
+             body = await request.read()
+             with open(path, "wb") as f:
+                 f.write(body)
+
+             return web.Response(status=200)
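The os.path.commonpath checks in get_request_user_filepath above are what stop a crafted file name from escaping the per-user directory. A standalone sketch of the same guard, with illustrative paths:

import os

def is_inside(root_dir, candidate):
    # Same check as get_request_user_filepath: the candidate path is only
    # accepted if its common prefix with the root is the root itself.
    root_dir = os.path.abspath(root_dir)
    candidate = os.path.abspath(os.path.join(root_dir, candidate))
    return os.path.commonpath((root_dir, candidate)) == root_dir

print(is_inside("user/default", "comfy.settings.json"))  # True -> request served
print(is_inside("user/default", "../../etc/passwd"))     # False -> request rejected (403)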
comfy/__pycache__/checkpoint_pickle.cpython-310.pyc ADDED
Binary file (716 Bytes)

comfy/__pycache__/cli_args.cpython-310.pyc ADDED
Binary file (6.7 kB)

comfy/__pycache__/clip_model.cpython-310.pyc ADDED
Binary file (8.5 kB)

comfy/__pycache__/clip_vision.cpython-310.pyc ADDED
Binary file (5.16 kB)

comfy/__pycache__/conds.cpython-310.pyc ADDED
Binary file (3.26 kB)

comfy/__pycache__/controlnet.cpython-310.pyc ADDED
Binary file (16.1 kB)

comfy/__pycache__/diffusers_convert.cpython-310.pyc ADDED
Binary file (6.69 kB)

comfy/__pycache__/diffusers_load.cpython-310.pyc ADDED
Binary file (1.31 kB)

comfy/__pycache__/gligen.cpython-310.pyc ADDED
Binary file (10.3 kB)

comfy/__pycache__/latent_formats.cpython-310.pyc ADDED
Binary file (2.15 kB)

comfy/__pycache__/lora.cpython-310.pyc ADDED
Binary file (5 kB)

comfy/__pycache__/model_base.cpython-310.pyc ADDED
Binary file (16.3 kB)

comfy/__pycache__/model_detection.cpython-310.pyc ADDED
Binary file (10 kB)

comfy/__pycache__/model_management.cpython-310.pyc ADDED
Binary file (18.8 kB)

comfy/__pycache__/model_patcher.cpython-310.pyc ADDED
Binary file (10.7 kB)

comfy/__pycache__/model_sampling.cpython-310.pyc ADDED
Binary file (7.26 kB)

comfy/__pycache__/ops.cpython-310.pyc ADDED
Binary file (6.65 kB)

comfy/__pycache__/options.cpython-310.pyc ADDED
Binary file (286 Bytes)

comfy/__pycache__/sample.cpython-310.pyc ADDED
Binary file (4.63 kB)

comfy/__pycache__/samplers.cpython-310.pyc ADDED
Binary file (19.9 kB)

comfy/__pycache__/sd.cpython-310.pyc ADDED
Binary file (19.4 kB)

comfy/__pycache__/sd1_clip.cpython-310.pyc ADDED
Binary file (16.1 kB)

comfy/__pycache__/sd2_clip.cpython-310.pyc ADDED
Binary file (2.01 kB)

comfy/__pycache__/sdxl_clip.cpython-310.pyc ADDED
Binary file (5.43 kB)

comfy/__pycache__/supported_models.cpython-310.pyc ADDED
Binary file (11.2 kB)

comfy/__pycache__/supported_models_base.cpython-310.pyc ADDED
Binary file (3.59 kB)

comfy/__pycache__/utils.cpython-310.pyc ADDED
Binary file (15.2 kB)

comfy/checkpoint_pickle.py ADDED
@@ -0,0 +1,13 @@
+ import pickle
+
+ load = pickle.load
+
+ class Empty:
+     pass
+
+ class Unpickler(pickle.Unpickler):
+     def find_class(self, module, name):
+         #TODO: safe unpickle
+         if module.startswith("pytorch_lightning"):
+             return Empty
+         return super().find_class(module, name)
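A small sketch of how this restricted Unpickler behaves: any class under pytorch_lightning.* resolves to the Empty placeholder, while ordinary objects unpickle as usual. It assumes the comfy package is importable from the repository root.

import io
import pickle

import comfy.checkpoint_pickle as checkpoint_pickle  # module added above

# A plain dict round-trips normally; a checkpoint carrying pytorch_lightning
# callback classes would have those classes replaced by Empty instead of failing.
payload = io.BytesIO(pickle.dumps({"step": 1000, "lr": 1e-4}))
data = checkpoint_pickle.Unpickler(payload).load()
print(data)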
comfy/cldm/__pycache__/cldm.cpython-310.pyc ADDED
Binary file (6.07 kB)
 
comfy/cldm/cldm.py ADDED
@@ -0,0 +1,312 @@
+ #taken from: https://github.com/lllyasviel/ControlNet
+ #and modified
+
+ import torch
+ import torch as th
+ import torch.nn as nn
+
+ from ..ldm.modules.diffusionmodules.util import (
+     zero_module,
+     timestep_embedding,
+ )
+
+ from ..ldm.modules.attention import SpatialTransformer
+ from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
+ from ..ldm.util import exists
+ import comfy.ops
+
+ class ControlledUnetModel(UNetModel):
+     #implemented in the ldm unet
+     pass
+
+ class ControlNet(nn.Module):
+     def __init__(
+         self,
+         image_size,
+         in_channels,
+         model_channels,
+         hint_channels,
+         num_res_blocks,
+         dropout=0,
+         channel_mult=(1, 2, 4, 8),
+         conv_resample=True,
+         dims=2,
+         num_classes=None,
+         use_checkpoint=False,
+         dtype=torch.float32,
+         num_heads=-1,
+         num_head_channels=-1,
+         num_heads_upsample=-1,
+         use_scale_shift_norm=False,
+         resblock_updown=False,
+         use_new_attention_order=False,
+         use_spatial_transformer=False,    # custom transformer support
+         transformer_depth=1,              # custom transformer support
+         context_dim=None,                 # custom transformer support
+         n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
+         legacy=True,
+         disable_self_attentions=None,
+         num_attention_blocks=None,
+         disable_middle_self_attn=False,
+         use_linear_in_transformer=False,
+         adm_in_channels=None,
+         transformer_depth_middle=None,
+         transformer_depth_output=None,
+         device=None,
+         operations=comfy.ops.disable_weight_init,
+         **kwargs,
+     ):
+         super().__init__()
+         assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
+         if use_spatial_transformer:
+             assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+         if context_dim is not None:
+             assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+             # from omegaconf.listconfig import ListConfig
+             # if type(context_dim) == ListConfig:
+             #     context_dim = list(context_dim)
+
+         if num_heads_upsample == -1:
+             num_heads_upsample = num_heads
+
+         if num_heads == -1:
+             assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+         if num_head_channels == -1:
+             assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+         self.dims = dims
+         self.image_size = image_size
+         self.in_channels = in_channels
+         self.model_channels = model_channels
+
+         if isinstance(num_res_blocks, int):
+             self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+         else:
+             if len(num_res_blocks) != len(channel_mult):
+                 raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+                                  "as a list/tuple (per-level) with the same length as channel_mult")
+             self.num_res_blocks = num_res_blocks
+
+         if disable_self_attentions is not None:
+             # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+             assert len(disable_self_attentions) == len(channel_mult)
+         if num_attention_blocks is not None:
+             assert len(num_attention_blocks) == len(self.num_res_blocks)
+             assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+
+         transformer_depth = transformer_depth[:]
+
+         self.dropout = dropout
+         self.channel_mult = channel_mult
+         self.conv_resample = conv_resample
+         self.num_classes = num_classes
+         self.use_checkpoint = use_checkpoint
+         self.dtype = dtype
+         self.num_heads = num_heads
+         self.num_head_channels = num_head_channels
+         self.num_heads_upsample = num_heads_upsample
+         self.predict_codebook_ids = n_embed is not None
+
+         time_embed_dim = model_channels * 4
+         self.time_embed = nn.Sequential(
+             operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
+             nn.SiLU(),
+             operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
+         )
+
+         if self.num_classes is not None:
+             if isinstance(self.num_classes, int):
+                 self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+             elif self.num_classes == "continuous":
+                 print("setting up linear c_adm embedding layer")
+                 self.label_emb = nn.Linear(1, time_embed_dim)
+             elif self.num_classes == "sequential":
+                 assert adm_in_channels is not None
+                 self.label_emb = nn.Sequential(
+                     nn.Sequential(
+                         operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
+                         nn.SiLU(),
+                         operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
+                     )
+                 )
+             else:
+                 raise ValueError()
+
+         self.input_blocks = nn.ModuleList(
+             [
+                 TimestepEmbedSequential(
+                     operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
+                 )
+             ]
+         )
+         self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
+
+         self.input_hint_block = TimestepEmbedSequential(
+             operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
+             nn.SiLU(),
+             operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
+             nn.SiLU(),
+             operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
+             nn.SiLU(),
+             operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
+             nn.SiLU(),
+             operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
+             nn.SiLU(),
+             operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
+             nn.SiLU(),
+             operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
+             nn.SiLU(),
+             operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
+         )
+
+         self._feature_size = model_channels
+         input_block_chans = [model_channels]
+         ch = model_channels
+         ds = 1
+         for level, mult in enumerate(channel_mult):
+             for nr in range(self.num_res_blocks[level]):
+                 layers = [
+                     ResBlock(
+                         ch,
+                         time_embed_dim,
+                         dropout,
+                         out_channels=mult * model_channels,
+                         dims=dims,
+                         use_checkpoint=use_checkpoint,
+                         use_scale_shift_norm=use_scale_shift_norm,
+                         dtype=self.dtype,
+                         device=device,
+                         operations=operations,
+                     )
+                 ]
+                 ch = mult * model_channels
+                 num_transformers = transformer_depth.pop(0)
+                 if num_transformers > 0:
+                     if num_head_channels == -1:
+                         dim_head = ch // num_heads
+                     else:
+                         num_heads = ch // num_head_channels
+                         dim_head = num_head_channels
+                     if legacy:
+                         #num_heads = 1
+                         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+                     if exists(disable_self_attentions):
+                         disabled_sa = disable_self_attentions[level]
+                     else:
+                         disabled_sa = False
+
+                     if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                         layers.append(
+                             SpatialTransformer(
+                                 ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
+                                 disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+                                 use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                             )
+                         )
+                 self.input_blocks.append(TimestepEmbedSequential(*layers))
+                 self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
+                 self._feature_size += ch
+                 input_block_chans.append(ch)
+             if level != len(channel_mult) - 1:
+                 out_ch = ch
+                 self.input_blocks.append(
+                     TimestepEmbedSequential(
+                         ResBlock(
+                             ch,
+                             time_embed_dim,
+                             dropout,
+                             out_channels=out_ch,
+                             dims=dims,
+                             use_checkpoint=use_checkpoint,
+                             use_scale_shift_norm=use_scale_shift_norm,
+                             down=True,
+                             dtype=self.dtype,
+                             device=device,
+                             operations=operations
+                         )
+                         if resblock_updown
+                         else Downsample(
+                             ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
+                         )
+                     )
+                 )
+                 ch = out_ch
+                 input_block_chans.append(ch)
+                 self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
+                 ds *= 2
+                 self._feature_size += ch
+
+         if num_head_channels == -1:
+             dim_head = ch // num_heads
+         else:
+             num_heads = ch // num_head_channels
+             dim_head = num_head_channels
+         if legacy:
+             #num_heads = 1
+             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+         mid_block = [
+             ResBlock(
+                 ch,
+                 time_embed_dim,
+                 dropout,
+                 dims=dims,
+                 use_checkpoint=use_checkpoint,
+                 use_scale_shift_norm=use_scale_shift_norm,
+                 dtype=self.dtype,
+                 device=device,
+                 operations=operations
+             )]
+         if transformer_depth_middle >= 0:
+             mid_block += [SpatialTransformer(  # always uses a self-attn
+                             ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
+                             disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+                             use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                         ),
+                         ResBlock(
+                             ch,
+                             time_embed_dim,
+                             dropout,
+                             dims=dims,
+                             use_checkpoint=use_checkpoint,
+                             use_scale_shift_norm=use_scale_shift_norm,
+                             dtype=self.dtype,
+                             device=device,
+                             operations=operations
+                         )]
+         self.middle_block = TimestepEmbedSequential(*mid_block)
+         self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
+         self._feature_size += ch
+
+     def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
+         return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
+
+     def forward(self, x, hint, timesteps, context, y=None, **kwargs):
+         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
+         emb = self.time_embed(t_emb)
+
+         guided_hint = self.input_hint_block(hint, emb, context)
+
+         outs = []
+
+         hs = []
+         if self.num_classes is not None:
+             assert y.shape[0] == x.shape[0]
+             emb = emb + self.label_emb(y)
+
+         h = x
+         for module, zero_conv in zip(self.input_blocks, self.zero_convs):
+             if guided_hint is not None:
+                 h = module(h, emb, context)
+                 h += guided_hint
+                 guided_hint = None
+             else:
+                 h = module(h, emb, context)
+             outs.append(zero_conv(h, emb, context))
+
+         h = self.middle_block(h, emb, context)
+         outs.append(self.middle_block_out(h, emb, context))
+
+         return outs
+
comfy/cli_args.py ADDED
@@ -0,0 +1,126 @@
+ import argparse
+ import enum
+ import comfy.options
+
+ class EnumAction(argparse.Action):
+     """
+     Argparse action for handling Enums
+     """
+     def __init__(self, **kwargs):
+         # Pop off the type value
+         enum_type = kwargs.pop("type", None)
+
+         # Ensure an Enum subclass is provided
+         if enum_type is None:
+             raise ValueError("type must be assigned an Enum when using EnumAction")
+         if not issubclass(enum_type, enum.Enum):
+             raise TypeError("type must be an Enum when using EnumAction")
+
+         # Generate choices from the Enum
+         choices = tuple(e.value for e in enum_type)
+         kwargs.setdefault("choices", choices)
+         kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
+
+         super(EnumAction, self).__init__(**kwargs)
+
+         self._enum = enum_type
+
+     def __call__(self, parser, namespace, values, option_string=None):
+         # Convert value back into an Enum
+         value = self._enum(values)
+         setattr(namespace, self.dest, value)
+
+
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
+ parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
+ parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
+ parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
+
+ parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
+ parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
+ parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
+ parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.")
+ parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
+ parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
+ parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
+ cm_group = parser.add_mutually_exclusive_group()
+ cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
+ cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
+
+ parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
+
+ fp_group = parser.add_mutually_exclusive_group()
+ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
+ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
+
+ fpunet_group = parser.add_mutually_exclusive_group()
+ fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
+ fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
+ fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
+ fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
+
+ fpvae_group = parser.add_mutually_exclusive_group()
+ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
+ fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
+ fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
+
+ parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
+
+ fpte_group = parser.add_mutually_exclusive_group()
+ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
+ fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
+ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
+ fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
+
+
+ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
+
+ parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
+
+ class LatentPreviewMethod(enum.Enum):
+     NoPreviews = "none"
+     Auto = "auto"
+     Latent2RGB = "latent2rgb"
+     TAESD = "taesd"
+
+ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
+
+ attn_group = parser.add_mutually_exclusive_group()
+ attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
+ attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
+ attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
+
+ parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
+
+ vram_group = parser.add_mutually_exclusive_group()
+ vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
+ vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
+ vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
+ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
+ vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
+ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
+
+
+ parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
+ parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
+
+ parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
+ parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
+ parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
+
+ parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
+
+ parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
+
+ if comfy.options.args_parsing:
+     args = parser.parse_args()
+ else:
+     args = parser.parse_args([])
+
+ if args.windows_standalone_build:
+     args.auto_launch = True
+
+ if args.disable_auto_launch:
+     args.auto_launch = False
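EnumAction maps string choices back to enum members, which is how --preview-method above accepts its LatentPreviewMethod values. A minimal sketch with a hypothetical enum, assuming the comfy package is importable:

import argparse
import enum

from comfy.cli_args import EnumAction  # class defined above

class Precision(enum.Enum):  # hypothetical enum, only for this example
    FP32 = "fp32"
    FP16 = "fp16"

example_parser = argparse.ArgumentParser()
example_parser.add_argument("--precision", type=Precision, default=Precision.FP32, action=EnumAction)

ns = example_parser.parse_args(["--precision", "fp16"])
print(ns.precision)  # Precision.FP16 -- the string choice is converted back to the enum member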
comfy/clip_config_bigg.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "architectures": [
+     "CLIPTextModel"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 0,
+   "dropout": 0.0,
+   "eos_token_id": 2,
+   "hidden_act": "gelu",
+   "hidden_size": 1280,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 5120,
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 77,
+   "model_type": "clip_text_model",
+   "num_attention_heads": 20,
+   "num_hidden_layers": 32,
+   "pad_token_id": 1,
+   "projection_dim": 1280,
+   "torch_dtype": "float32",
+   "vocab_size": 49408
+ }
comfy/clip_model.py ADDED
@@ -0,0 +1,188 @@
+ import torch
+ from comfy.ldm.modules.attention import optimized_attention_for_device
+
+ class CLIPAttention(torch.nn.Module):
+     def __init__(self, embed_dim, heads, dtype, device, operations):
+         super().__init__()
+
+         self.heads = heads
+         self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+         self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+         self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+
+         self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+
+     def forward(self, x, mask=None, optimized_attention=None):
+         q = self.q_proj(x)
+         k = self.k_proj(x)
+         v = self.v_proj(x)
+
+         out = optimized_attention(q, k, v, self.heads, mask)
+         return self.out_proj(out)
+
+ ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
+                "gelu": torch.nn.functional.gelu,
+ }
+
+ class CLIPMLP(torch.nn.Module):
+     def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
+         super().__init__()
+         self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
+         self.activation = ACTIVATIONS[activation]
+         self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.activation(x)
+         x = self.fc2(x)
+         return x
+
+ class CLIPLayer(torch.nn.Module):
+     def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
+         super().__init__()
+         self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+         self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
+         self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+         self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)
+
+     def forward(self, x, mask=None, optimized_attention=None):
+         x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
+         x += self.mlp(self.layer_norm2(x))
+         return x
+
+
+ class CLIPEncoder(torch.nn.Module):
+     def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
+         super().__init__()
+         self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
+
+     def forward(self, x, mask=None, intermediate_output=None):
+         optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
+
+         if intermediate_output is not None:
+             if intermediate_output < 0:
+                 intermediate_output = len(self.layers) + intermediate_output
+
+         intermediate = None
+         for i, l in enumerate(self.layers):
+             x = l(x, mask, optimized_attention)
+             if i == intermediate_output:
+                 intermediate = x.clone()
+         return x, intermediate
+
+ class CLIPEmbeddings(torch.nn.Module):
+     def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
+         super().__init__()
+         self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
+         self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+
+     def forward(self, input_tokens):
+         return self.token_embedding(input_tokens) + self.position_embedding.weight
+
+
+ class CLIPTextModel_(torch.nn.Module):
+     def __init__(self, config_dict, dtype, device, operations):
+         num_layers = config_dict["num_hidden_layers"]
+         embed_dim = config_dict["hidden_size"]
+         heads = config_dict["num_attention_heads"]
+         intermediate_size = config_dict["intermediate_size"]
+         intermediate_activation = config_dict["hidden_act"]
+
+         super().__init__()
+         self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
+         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
+         self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+
+     def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
+         x = self.embeddings(input_tokens)
+         mask = None
+         if attention_mask is not None:
+             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
+             mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
+
+         causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
+         if mask is not None:
+             mask += causal_mask
+         else:
+             mask = causal_mask
+
+         x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
+         x = self.final_layer_norm(x)
+         if i is not None and final_layer_norm_intermediate:
+             i = self.final_layer_norm(i)
+
+         pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
+         return x, i, pooled_output
+
+ class CLIPTextModel(torch.nn.Module):
+     def __init__(self, config_dict, dtype, device, operations):
+         super().__init__()
+         self.num_layers = config_dict["num_hidden_layers"]
+         self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
+         self.dtype = dtype
+
+     def get_input_embeddings(self):
+         return self.text_model.embeddings.token_embedding
+
+     def set_input_embeddings(self, embeddings):
+         self.text_model.embeddings.token_embedding = embeddings
+
+     def forward(self, *args, **kwargs):
+         return self.text_model(*args, **kwargs)
+
+ class CLIPVisionEmbeddings(torch.nn.Module):
+     def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
+         super().__init__()
+         self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
+
+         self.patch_embedding = operations.Conv2d(
+             in_channels=num_channels,
+             out_channels=embed_dim,
+             kernel_size=patch_size,
+             stride=patch_size,
+             bias=False,
+             dtype=dtype,
+             device=device
+         )
+
+         num_patches = (image_size // patch_size) ** 2
+         num_positions = num_patches + 1
+         self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+
+     def forward(self, pixel_values):
+         embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
+         return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
+
+
+ class CLIPVision(torch.nn.Module):
+     def __init__(self, config_dict, dtype, device, operations):
+         super().__init__()
+         num_layers = config_dict["num_hidden_layers"]
+         embed_dim = config_dict["hidden_size"]
+         heads = config_dict["num_attention_heads"]
+         intermediate_size = config_dict["intermediate_size"]
+         intermediate_activation = config_dict["hidden_act"]
+
+         self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations)
+         self.pre_layrnorm = operations.LayerNorm(embed_dim)
+         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
+         self.post_layernorm = operations.LayerNorm(embed_dim)
+
+     def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
+         x = self.embeddings(pixel_values)
+         x = self.pre_layrnorm(x)
+         #TODO: attention_mask?
+         x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
+         pooled_output = self.post_layernorm(x[:, 0, :])
+         return x, i, pooled_output
+
+ class CLIPVisionModelProjection(torch.nn.Module):
+     def __init__(self, config_dict, dtype, device, operations):
+         super().__init__()
+         self.vision_model = CLIPVision(config_dict, dtype, device, operations)
+         self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
+
+     def forward(self, *args, **kwargs):
+         x = self.vision_model(*args, **kwargs)
+         out = self.visual_projection(x[2])
+         return (x[0], x[1], out)
comfy/clip_vision.py ADDED
@@ -0,0 +1,116 @@
+ from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
+ import os
+ import torch
+ import json
+
+ import comfy.ops
+ import comfy.model_patcher
+ import comfy.model_management
+ import comfy.utils
+ import comfy.clip_model
+
+ class Output:
+     def __getitem__(self, key):
+         return getattr(self, key)
+     def __setitem__(self, key, item):
+         setattr(self, key, item)
+
+ def clip_preprocess(image, size=224):
+     mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
+     std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
+     image = image.movedim(-1, 1)
+     if not (image.shape[2] == size and image.shape[3] == size):
+         scale = (size / min(image.shape[2], image.shape[3]))
+         image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
+         h = (image.shape[2] - size)//2
+         w = (image.shape[3] - size)//2
+         image = image[:,:,h:h+size,w:w+size]
+     image = torch.clip((255. * image), 0, 255).round() / 255.0
+     return (image - mean.view([3,1,1])) / std.view([3,1,1])
+
+ class ClipVisionModel():
+     def __init__(self, json_config):
+         with open(json_config) as f:
+             config = json.load(f)
+
+         self.load_device = comfy.model_management.text_encoder_device()
+         offload_device = comfy.model_management.text_encoder_offload_device()
+         self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
+         self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
+         self.model.eval()
+
+         self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
+
+     def load_sd(self, sd):
+         return self.model.load_state_dict(sd, strict=False)
+
+     def get_sd(self):
+         return self.model.state_dict()
+
+     def encode_image(self, image):
+         comfy.model_management.load_model_gpu(self.patcher)
+         pixel_values = clip_preprocess(image.to(self.load_device)).float()
+         out = self.model(pixel_values=pixel_values, intermediate_output=-2)
+
+         outputs = Output()
+         outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
+         outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
+         outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+         return outputs
+
+ def convert_to_transformers(sd, prefix):
+     sd_k = sd.keys()
+     if "{}transformer.resblocks.0.attn.in_proj_weight".format(prefix) in sd_k:
+         keys_to_replace = {
+             "{}class_embedding".format(prefix): "vision_model.embeddings.class_embedding",
+             "{}conv1.weight".format(prefix): "vision_model.embeddings.patch_embedding.weight",
+             "{}positional_embedding".format(prefix): "vision_model.embeddings.position_embedding.weight",
+             "{}ln_post.bias".format(prefix): "vision_model.post_layernorm.bias",
+             "{}ln_post.weight".format(prefix): "vision_model.post_layernorm.weight",
+             "{}ln_pre.bias".format(prefix): "vision_model.pre_layrnorm.bias",
+             "{}ln_pre.weight".format(prefix): "vision_model.pre_layrnorm.weight",
+         }
+
+         for x in keys_to_replace:
+             if x in sd_k:
+                 sd[keys_to_replace[x]] = sd.pop(x)
+
+         if "{}proj".format(prefix) in sd_k:
+             sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)
+
+         sd = transformers_convert(sd, prefix, "vision_model.", 48)
+     else:
+         replace_prefix = {prefix: ""}
+         sd = state_dict_prefix_replace(sd, replace_prefix)
+     return sd
+
+ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
+     if convert_keys:
+         sd = convert_to_transformers(sd, prefix)
+     if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
+         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
+     elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
+         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
+     elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
+         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
+     else:
+         return None
+
+     clip = ClipVisionModel(json_config)
+     m, u = clip.load_sd(sd)
+     if len(m) > 0:
+         print("missing clip vision:", m)
+     u = set(u)
+     keys = list(sd.keys())
+     for k in keys:
+         if k not in u:
+             t = sd.pop(k)
+             del t
+     return clip
+
+ def load(ckpt_path):
+     sd = load_torch_file(ckpt_path)
+     if "visual.transformer.resblocks.0.attn.in_proj_weight" in sd:
+         return load_clipvision_from_sd(sd, prefix="visual.", convert_keys=True)
+     else:
+         return load_clipvision_from_sd(sd)
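clip_preprocess above expects images as [batch, height, width, channel] tensors with values in 0..1; it rescales the short side to 224, center-crops, and applies the CLIP mean/std normalization. A quick sketch with a random tensor, assuming the comfy package is importable from the repository root:

import torch

from comfy.clip_vision import clip_preprocess  # function defined above

# 512x768 RGB image in channels-last layout; the short side is scaled to 224
# and the result is center-cropped to a square before normalization.
image = torch.rand(1, 512, 768, 3)
pixel_values = clip_preprocess(image)
print(pixel_values.shape)  # torch.Size([1, 3, 224, 224])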
comfy/clip_vision_config_g.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "attention_dropout": 0.0,
+   "dropout": 0.0,
+   "hidden_act": "gelu",
+   "hidden_size": 1664,
+   "image_size": 224,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 8192,
+   "layer_norm_eps": 1e-05,
+   "model_type": "clip_vision_model",
+   "num_attention_heads": 16,
+   "num_channels": 3,
+   "num_hidden_layers": 48,
+   "patch_size": 14,
+   "projection_dim": 1280,
+   "torch_dtype": "float32"
+ }
comfy/clip_vision_config_h.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "attention_dropout": 0.0,
+   "dropout": 0.0,
+   "hidden_act": "gelu",
+   "hidden_size": 1280,
+   "image_size": 224,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 5120,
+   "layer_norm_eps": 1e-05,
+   "model_type": "clip_vision_model",
+   "num_attention_heads": 16,
+   "num_channels": 3,
+   "num_hidden_layers": 32,
+   "patch_size": 14,
+   "projection_dim": 1024,
+   "torch_dtype": "float32"
+ }
comfy/clip_vision_config_vitl.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "attention_dropout": 0.0,
+   "dropout": 0.0,
+   "hidden_act": "quick_gelu",
+   "hidden_size": 1024,
+   "image_size": 224,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-05,
+   "model_type": "clip_vision_model",
+   "num_attention_heads": 16,
+   "num_channels": 3,
+   "num_hidden_layers": 24,
+   "patch_size": 14,
+   "projection_dim": 768,
+   "torch_dtype": "float32"
+ }
comfy/conds.py ADDED
@@ -0,0 +1,78 @@
+ import torch
+ import math
+ import comfy.utils
+
+
+ def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
+     return abs(a*b) // math.gcd(a, b)
+
+ class CONDRegular:
+     def __init__(self, cond):
+         self.cond = cond
+
+     def _copy_with(self, cond):
+         return self.__class__(cond)
+
+     def process_cond(self, batch_size, device, **kwargs):
+         return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
+
+     def can_concat(self, other):
+         if self.cond.shape != other.cond.shape:
+             return False
+         return True
+
+     def concat(self, others):
+         conds = [self.cond]
+         for x in others:
+             conds.append(x.cond)
+         return torch.cat(conds)
+
+ class CONDNoiseShape(CONDRegular):
+     def process_cond(self, batch_size, device, area, **kwargs):
+         data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
+         return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
+
+
+ class CONDCrossAttn(CONDRegular):
+     def can_concat(self, other):
+         s1 = self.cond.shape
+         s2 = other.cond.shape
+         if s1 != s2:
+             if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
+                 return False
+
+             mult_min = lcm(s1[1], s2[1])
+             diff = mult_min // min(s1[1], s2[1])
+             if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
+                 return False
+         return True
+
+     def concat(self, others):
+         conds = [self.cond]
+         crossattn_max_len = self.cond.shape[1]
+         for x in others:
+             c = x.cond
+             crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
+             conds.append(c)
+
+         out = []
+         for c in conds:
+             if c.shape[1] < crossattn_max_len:
+                 c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
+             out.append(c)
+         return torch.cat(out)
+
+ class CONDConstant(CONDRegular):
+     def __init__(self, cond):
+         self.cond = cond
+
+     def process_cond(self, batch_size, device, **kwargs):
+         return self._copy_with(self.cond)
+
+     def can_concat(self, other):
+         if self.cond != other.cond:
+             return False
+         return True
+
+     def concat(self, others):
+         return self.cond
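CONDCrossAttn above lets cross-attention conditionings with different token counts be batched together by tiling the shorter one along the token axis up to the least common multiple of the lengths (capped at a 4x padding factor). A small sketch with dummy tensors, assuming the comfy package is importable:

import torch

from comfy.conds import CONDCrossAttn  # class defined above

# 77-token and 154-token conditionings: lcm is 154, padding factor 2 (<= 4),
# so they can be concatenated; the shorter tensor is repeated twice along dim 1.
a = CONDCrossAttn(torch.randn(1, 77, 768))
b = CONDCrossAttn(torch.randn(1, 154, 768))
print(a.can_concat(b))      # True
print(a.concat([b]).shape)  # torch.Size([2, 154, 768])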
comfy/controlnet.py ADDED
@@ -0,0 +1,525 @@
+ import torch
+ import math
+ import os
+ import comfy.utils
+ import comfy.model_management
+ import comfy.model_detection
+ import comfy.model_patcher
+ import comfy.ops
+
+ import comfy.cldm.cldm
+ import comfy.t2i_adapter.adapter
+
+
+ def broadcast_image_to(tensor, target_batch_size, batched_number):
+     current_batch_size = tensor.shape[0]
+     #print(current_batch_size, target_batch_size)
+     if current_batch_size == 1:
+         return tensor
+
+     per_batch = target_batch_size // batched_number
+     tensor = tensor[:per_batch]
+
+     if per_batch > tensor.shape[0]:
+         tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0)
+
+     current_batch_size = tensor.shape[0]
+     if current_batch_size == target_batch_size:
+         return tensor
+     else:
+         return torch.cat([tensor] * batched_number, dim=0)
+
+ class ControlBase:
+     def __init__(self, device=None):
+         self.cond_hint_original = None
+         self.cond_hint = None
+         self.strength = 1.0
+         self.timestep_percent_range = (0.0, 1.0)
+         self.global_average_pooling = False
+         self.timestep_range = None
+
+         if device is None:
+             device = comfy.model_management.get_torch_device()
+         self.device = device
+         self.previous_controlnet = None
+
+     def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
+         self.cond_hint_original = cond_hint
+         self.strength = strength
+         self.timestep_percent_range = timestep_percent_range
+         return self
+
+     def pre_run(self, model, percent_to_timestep_function):
+         self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
+         if self.previous_controlnet is not None:
+             self.previous_controlnet.pre_run(model, percent_to_timestep_function)
+
+     def set_previous_controlnet(self, controlnet):
+         self.previous_controlnet = controlnet
+         return self
+
+     def cleanup(self):
+         if self.previous_controlnet is not None:
+             self.previous_controlnet.cleanup()
+         if self.cond_hint is not None:
+             del self.cond_hint
+             self.cond_hint = None
+         self.timestep_range = None
+
+     def get_models(self):
+         out = []
+         if self.previous_controlnet is not None:
+             out += self.previous_controlnet.get_models()
+         return out
+
+     def copy_to(self, c):
+         c.cond_hint_original = self.cond_hint_original
+         c.strength = self.strength
+         c.timestep_percent_range = self.timestep_percent_range
+         c.global_average_pooling = self.global_average_pooling
+
+     def inference_memory_requirements(self, dtype):
+         if self.previous_controlnet is not None:
+             return self.previous_controlnet.inference_memory_requirements(dtype)
+         return 0
+
+     def control_merge(self, control_input, control_output, control_prev, output_dtype):
+         out = {'input':[], 'middle':[], 'output': []}
+
+         if control_input is not None:
+             for i in range(len(control_input)):
+                 key = 'input'
+                 x = control_input[i]
+                 if x is not None:
+                     x *= self.strength
+                     if x.dtype != output_dtype:
+                         x = x.to(output_dtype)
+                 out[key].insert(0, x)
+
+         if control_output is not None:
+             for i in range(len(control_output)):
+                 if i == (len(control_output) - 1):
+                     key = 'middle'
+                     index = 0
+                 else:
+                     key = 'output'
+                     index = i
+                 x = control_output[i]
+                 if x is not None:
+                     if self.global_average_pooling:
+                         x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
+
+                     x *= self.strength
+                     if x.dtype != output_dtype:
+                         x = x.to(output_dtype)
+
+                 out[key].append(x)
+         if control_prev is not None:
+             for x in ['input', 'middle', 'output']:
+                 o = out[x]
+                 for i in range(len(control_prev[x])):
+                     prev_val = control_prev[x][i]
+                     if i >= len(o):
+                         o.append(prev_val)
+                     elif prev_val is not None:
+                         if o[i] is None:
+                             o[i] = prev_val
+                         else:
+                             if o[i].shape[0] < prev_val.shape[0]:
+                                 o[i] = prev_val + o[i]
+                             else:
+                                 o[i] += prev_val
+         return out
+
+ class ControlNet(ControlBase):
+     def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
+         super().__init__(device)
+         self.control_model = control_model
+         self.load_device = load_device
+         self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
+         self.global_average_pooling = global_average_pooling
+         self.model_sampling_current = None
+         self.manual_cast_dtype = manual_cast_dtype
+
+     def get_control(self, x_noisy, t, cond, batched_number):
+         control_prev = None
+         if self.previous_controlnet is not None:
+             control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
+
+         if self.timestep_range is not None:
+             if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
+                 if control_prev is not None:
+                     return control_prev
+                 else:
+                     return None
+
+         dtype = self.control_model.dtype
+         if self.manual_cast_dtype is not None:
+             dtype = self.manual_cast_dtype
+
+         output_dtype = x_noisy.dtype
+         if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
+             if self.cond_hint is not None:
+                 del self.cond_hint
+                 self.cond_hint = None
+             self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
+         if x_noisy.shape[0] != self.cond_hint.shape[0]:
+             self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
+
+         context = cond.get('crossattn_controlnet', cond['c_crossattn'])
+         y = cond.get('y', None)
+         if y is not None:
+             y = y.to(dtype)
+         timestep = self.model_sampling_current.timestep(t)
+         x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
+
+         control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
+         return self.control_merge(None, control, control_prev, output_dtype)
+
+     def copy(self):
+         c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+         self.copy_to(c)
182
+ return c
183
+
184
+ def get_models(self):
185
+ out = super().get_models()
186
+ out.append(self.control_model_wrapped)
187
+ return out
188
+
189
+ def pre_run(self, model, percent_to_timestep_function):
190
+ super().pre_run(model, percent_to_timestep_function)
191
+ self.model_sampling_current = model.model_sampling
192
+
193
+ def cleanup(self):
194
+ self.model_sampling_current = None
195
+ super().cleanup()
196
+
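A minimal usage sketch for the ControlBase/ControlNet chain above, assuming cn_a and cn_b are ControlNet objects produced by load_controlnet (defined later in this file) and hint_a / hint_b are image tensors in NCHW layout:

    # cn_a runs at full strength for the whole schedule; cn_b runs at half strength for
    # roughly the first 60% of sampling, after which only cn_a's contribution remains
    c = cn_a.set_cond_hint(hint_a, strength=1.0)
    c = cn_b.set_cond_hint(hint_b, strength=0.5, timestep_percent_range=(0.0, 0.6)).set_previous_controlnet(c)
    # during sampling, get_control(x_noisy, t, cond, batched_number) is called on the last
    # object in the chain and recursively merges earlier results via control_merge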
197
+ class ControlLoraOps:
198
+ class Linear(torch.nn.Module):
199
+ def __init__(self, in_features: int, out_features: int, bias: bool = True,
200
+ device=None, dtype=None) -> None:
201
+ factory_kwargs = {'device': device, 'dtype': dtype}
202
+ super().__init__()
203
+ self.in_features = in_features
204
+ self.out_features = out_features
205
+ self.weight = None
206
+ self.up = None
207
+ self.down = None
208
+ self.bias = None
209
+
210
+ def forward(self, input):
211
+ weight, bias = comfy.ops.cast_bias_weight(self, input)
212
+ if self.up is not None:
213
+ return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
214
+ else:
215
+ return torch.nn.functional.linear(input, weight, bias)
216
+
217
+ class Conv2d(torch.nn.Module):
218
+ def __init__(
219
+ self,
220
+ in_channels,
221
+ out_channels,
222
+ kernel_size,
223
+ stride=1,
224
+ padding=0,
225
+ dilation=1,
226
+ groups=1,
227
+ bias=True,
228
+ padding_mode='zeros',
229
+ device=None,
230
+ dtype=None
231
+ ):
232
+ super().__init__()
233
+ self.in_channels = in_channels
234
+ self.out_channels = out_channels
235
+ self.kernel_size = kernel_size
236
+ self.stride = stride
237
+ self.padding = padding
238
+ self.dilation = dilation
239
+ self.transposed = False
240
+ self.output_padding = 0
241
+ self.groups = groups
242
+ self.padding_mode = padding_mode
243
+
244
+ self.weight = None
245
+ self.bias = None
246
+ self.up = None
247
+ self.down = None
248
+
249
+
250
+ def forward(self, input):
251
+ weight, bias = comfy.ops.cast_bias_weight(self, input)
252
+ if self.up is not None:
253
+ return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
254
+ else:
255
+ return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
256
+
257
+
258
+ class ControlLora(ControlNet):
259
+ def __init__(self, control_weights, global_average_pooling=False, device=None):
260
+ ControlBase.__init__(self, device)
261
+ self.control_weights = control_weights
262
+ self.global_average_pooling = global_average_pooling
263
+
264
+ def pre_run(self, model, percent_to_timestep_function):
265
+ super().pre_run(model, percent_to_timestep_function)
266
+ controlnet_config = model.model_config.unet_config.copy()
267
+ controlnet_config.pop("out_channels")
268
+ controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
269
+ self.manual_cast_dtype = model.manual_cast_dtype
270
+ dtype = model.get_dtype()
271
+ if self.manual_cast_dtype is None:
272
+ class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
273
+ pass
274
+ else:
275
+ class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast):
276
+ pass
277
+ dtype = self.manual_cast_dtype
278
+
279
+ controlnet_config["operations"] = control_lora_ops
280
+ controlnet_config["dtype"] = dtype
281
+ self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
282
+ self.control_model.to(comfy.model_management.get_torch_device())
283
+ diffusion_model = model.diffusion_model
284
+ sd = diffusion_model.state_dict()
285
+ cm = self.control_model.state_dict()
286
+
287
+ for k in sd:
288
+ weight = sd[k]
289
+ try:
290
+ comfy.utils.set_attr(self.control_model, k, weight)
291
+ except:
292
+ pass
293
+
294
+ for k in self.control_weights:
295
+ if k not in {"lora_controlnet"}:
296
+ comfy.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
297
+
298
+ def copy(self):
299
+ c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
300
+ self.copy_to(c)
301
+ return c
302
+
303
+ def cleanup(self):
304
+ del self.control_model
305
+ self.control_model = None
306
+ super().cleanup()
307
+
308
+ def get_models(self):
309
+ out = ControlBase.get_models(self)
310
+ return out
311
+
312
+ def inference_memory_requirements(self, dtype):
313
+ return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
314
+
315
+ def load_controlnet(ckpt_path, model=None):
316
+ controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
317
+ if "lora_controlnet" in controlnet_data:
318
+ return ControlLora(controlnet_data)
319
+
320
+ controlnet_config = None
321
+ supported_inference_dtypes = None
322
+
323
+ if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
324
+ controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
325
+ diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
326
+ diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
327
+ diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
328
+
329
+ count = 0
330
+ loop = True
331
+ while loop:
332
+ suffix = [".weight", ".bias"]
333
+ for s in suffix:
334
+ k_in = "controlnet_down_blocks.{}{}".format(count, s)
335
+ k_out = "zero_convs.{}.0{}".format(count, s)
336
+ if k_in not in controlnet_data:
337
+ loop = False
338
+ break
339
+ diffusers_keys[k_in] = k_out
340
+ count += 1
341
+
342
+ count = 0
343
+ loop = True
344
+ while loop:
345
+ suffix = [".weight", ".bias"]
346
+ for s in suffix:
347
+ if count == 0:
348
+ k_in = "controlnet_cond_embedding.conv_in{}".format(s)
349
+ else:
350
+ k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
351
+ k_out = "input_hint_block.{}{}".format(count * 2, s)
352
+ if k_in not in controlnet_data:
353
+ k_in = "controlnet_cond_embedding.conv_out{}".format(s)
354
+ loop = False
355
+ diffusers_keys[k_in] = k_out
356
+ count += 1
357
+
358
+ new_sd = {}
359
+ for k in diffusers_keys:
360
+ if k in controlnet_data:
361
+ new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
362
+
363
+ leftover_keys = controlnet_data.keys()
364
+ if len(leftover_keys) > 0:
365
+ print("leftover keys:", leftover_keys)
366
+ controlnet_data = new_sd
367
+
368
+ pth_key = 'control_model.zero_convs.0.0.weight'
369
+ pth = False
370
+ key = 'zero_convs.0.0.weight'
371
+ if pth_key in controlnet_data:
372
+ pth = True
373
+ key = pth_key
374
+ prefix = "control_model."
375
+ elif key in controlnet_data:
376
+ prefix = ""
377
+ else:
378
+ net = load_t2i_adapter(controlnet_data)
379
+ if net is None:
380
+ print("error: checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
381
+ return net
382
+
383
+ if controlnet_config is None:
384
+ model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
385
+ supported_inference_dtypes = model_config.supported_inference_dtypes
386
+ controlnet_config = model_config.unet_config
387
+
388
+ load_device = comfy.model_management.get_torch_device()
389
+ if supported_inference_dtypes is None:
390
+ unet_dtype = comfy.model_management.unet_dtype()
391
+ else:
392
+ unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
393
+
394
+ manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
395
+ if manual_cast_dtype is not None:
396
+ controlnet_config["operations"] = comfy.ops.manual_cast
397
+ controlnet_config["dtype"] = unet_dtype
398
+ controlnet_config.pop("out_channels")
399
+ controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
400
+ control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
401
+
402
+ if pth:
403
+ if 'difference' in controlnet_data:
404
+ if model is not None:
405
+ comfy.model_management.load_models_gpu([model])
406
+ model_sd = model.model_state_dict()
407
+ for x in controlnet_data:
408
+ c_m = "control_model."
409
+ if x.startswith(c_m):
410
+ sd_key = "diffusion_model.{}".format(x[len(c_m):])
411
+ if sd_key in model_sd:
412
+ cd = controlnet_data[x]
413
+ cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
414
+ else:
415
+ print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
416
+
417
+ class WeightsLoader(torch.nn.Module):
418
+ pass
419
+ w = WeightsLoader()
420
+ w.control_model = control_model
421
+ missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
422
+ else:
423
+ missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
424
+ print(missing, unexpected)
425
+
426
+ global_average_pooling = False
427
+ filename = os.path.splitext(ckpt_path)[0]
428
+ if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
429
+ global_average_pooling = True
430
+
431
+ control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
432
+ return control
433
+
434
+ class T2IAdapter(ControlBase):
435
+ def __init__(self, t2i_model, channels_in, device=None):
436
+ super().__init__(device)
437
+ self.t2i_model = t2i_model
438
+ self.channels_in = channels_in
439
+ self.control_input = None
440
+
441
+ def scale_image_to(self, width, height):
442
+ unshuffle_amount = self.t2i_model.unshuffle_amount
443
+ width = math.ceil(width / unshuffle_amount) * unshuffle_amount
444
+ height = math.ceil(height / unshuffle_amount) * unshuffle_amount
445
+ return width, height
446
+
447
+ def get_control(self, x_noisy, t, cond, batched_number):
448
+ control_prev = None
449
+ if self.previous_controlnet is not None:
450
+ control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
451
+
452
+ if self.timestep_range is not None:
453
+ if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
454
+ if control_prev is not None:
455
+ return control_prev
456
+ else:
457
+ return None
458
+
459
+ if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
460
+ if self.cond_hint is not None:
461
+ del self.cond_hint
462
+ self.control_input = None
463
+ self.cond_hint = None
464
+ width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8)
465
+ self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device)
466
+ if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
467
+ self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
468
+ if x_noisy.shape[0] != self.cond_hint.shape[0]:
469
+ self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
470
+ if self.control_input is None:
471
+ self.t2i_model.to(x_noisy.dtype)
472
+ self.t2i_model.to(self.device)
473
+ self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
474
+ self.t2i_model.cpu()
475
+
476
+ control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))
477
+ mid = None
478
+ if self.t2i_model.xl == True:
479
+ mid = control_input[-1:]
480
+ control_input = control_input[:-1]
481
+ return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)
482
+
483
+ def copy(self):
484
+ c = T2IAdapter(self.t2i_model, self.channels_in)
485
+ self.copy_to(c)
486
+ return c
487
+
488
+ def load_t2i_adapter(t2i_data):
489
+ if 'adapter' in t2i_data:
490
+ t2i_data = t2i_data['adapter']
491
+ if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
492
+ prefix_replace = {}
493
+ for i in range(4):
494
+ for j in range(2):
495
+ prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
496
+ prefix_replace["adapter.body.{}.".format(i, j)] = "body.{}.".format(i * 2)
497
+ prefix_replace["adapter."] = ""
498
+ t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
499
+ keys = t2i_data.keys()
500
+
501
+ if "body.0.in_conv.weight" in keys:
502
+ cin = t2i_data['body.0.in_conv.weight'].shape[1]
503
+ model_ad = comfy.t2i_adapter.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
504
+ elif 'conv_in.weight' in keys:
505
+ cin = t2i_data['conv_in.weight'].shape[1]
506
+ channel = t2i_data['conv_in.weight'].shape[0]
507
+ ksize = t2i_data['body.0.block2.weight'].shape[2]
508
+ use_conv = False
509
+ down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys))
510
+ if len(down_opts) > 0:
511
+ use_conv = True
512
+ xl = False
513
+ if cin == 256 or cin == 768:
514
+ xl = True
515
+ model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
516
+ else:
517
+ return None
518
+ missing, unexpected = model_ad.load_state_dict(t2i_data)
519
+ if len(missing) > 0:
520
+ print("t2i missing", missing)
521
+
522
+ if len(unexpected) > 0:
523
+ print("t2i unexpected", unexpected)
524
+
525
+ return T2IAdapter(model_ad, model_ad.input_channels)
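A minimal sketch of how the loaders above dispatch, assuming this module is importable as comfy.controlnet and using a hypothetical adapter checkpoint path:

    import comfy.utils, comfy.controlnet
    data = comfy.utils.load_torch_file("models/t2i_adapter/t2iadapter_depth.pth", safe_load=True)  # hypothetical path
    adapter = comfy.controlnet.load_t2i_adapter(data)   # returns a T2IAdapter, or None if the keys are not recognized
    # load_controlnet(ckpt_path) performs the same fallback automatically: ControlLora when the
    # "lora_controlnet" key is present, ControlNet for controlnet state dicts, otherwise load_t2i_adapter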
comfy/diffusers_convert.py ADDED
@@ -0,0 +1,261 @@
1
+ import re
2
+ import torch
3
+
4
+ # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
5
+
6
+ # =================#
7
+ # UNet Conversion #
8
+ # =================#
9
+
10
+ unet_conversion_map = [
11
+ # (stable-diffusion, HF Diffusers)
12
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
13
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
14
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
15
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
16
+ ("input_blocks.0.0.weight", "conv_in.weight"),
17
+ ("input_blocks.0.0.bias", "conv_in.bias"),
18
+ ("out.0.weight", "conv_norm_out.weight"),
19
+ ("out.0.bias", "conv_norm_out.bias"),
20
+ ("out.2.weight", "conv_out.weight"),
21
+ ("out.2.bias", "conv_out.bias"),
22
+ ]
23
+
24
+ unet_conversion_map_resnet = [
25
+ # (stable-diffusion, HF Diffusers)
26
+ ("in_layers.0", "norm1"),
27
+ ("in_layers.2", "conv1"),
28
+ ("out_layers.0", "norm2"),
29
+ ("out_layers.3", "conv2"),
30
+ ("emb_layers.1", "time_emb_proj"),
31
+ ("skip_connection", "conv_shortcut"),
32
+ ]
33
+
34
+ unet_conversion_map_layer = []
35
+ # hardcoded number of downblocks and resnets/attentions...
36
+ # would need smarter logic for other networks.
37
+ for i in range(4):
38
+ # loop over downblocks/upblocks
39
+
40
+ for j in range(2):
41
+ # loop over resnets/attentions for downblocks
42
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
43
+ sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
44
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
45
+
46
+ if i < 3:
47
+ # no attention layers in down_blocks.3
48
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
49
+ sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
50
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
51
+
52
+ for j in range(3):
53
+ # loop over resnets/attentions for upblocks
54
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
55
+ sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
56
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
57
+
58
+ if i > 0:
59
+ # no attention layers in up_blocks.0
60
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
61
+ sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
62
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
63
+
64
+ if i < 3:
65
+ # no downsample in down_blocks.3
66
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
67
+ sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
68
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
69
+
70
+ # no upsample in up_blocks.3
71
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
72
+ sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
73
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
74
+
75
+ hf_mid_atn_prefix = "mid_block.attentions.0."
76
+ sd_mid_atn_prefix = "middle_block.1."
77
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
78
+
79
+ for j in range(2):
80
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
81
+ sd_mid_res_prefix = f"middle_block.{2 * j}."
82
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
83
+
84
+
85
+ def convert_unet_state_dict(unet_state_dict):
86
+ # buyer beware: this is a *brittle* function,
87
+ # and correct output requires that all of these pieces interact in
88
+ # the exact order in which I have arranged them.
89
+ mapping = {k: k for k in unet_state_dict.keys()}
90
+ for sd_name, hf_name in unet_conversion_map:
91
+ mapping[hf_name] = sd_name
92
+ for k, v in mapping.items():
93
+ if "resnets" in k:
94
+ for sd_part, hf_part in unet_conversion_map_resnet:
95
+ v = v.replace(hf_part, sd_part)
96
+ mapping[k] = v
97
+ for k, v in mapping.items():
98
+ for sd_part, hf_part in unet_conversion_map_layer:
99
+ v = v.replace(hf_part, sd_part)
100
+ mapping[k] = v
101
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
102
+ return new_state_dict
103
+
104
+
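A small sketch of the mapping direction, assuming a hypothetical diffusers-format UNet state dict: keys are renamed from HF Diffusers names to stable-diffusion names, and the tensor values are passed through unchanged:

    sd_unet = convert_unet_state_dict(diffusers_unet)   # diffusers_unet: hypothetical dict of tensors
    # e.g. "time_embedding.linear_1.weight"        -> "time_embed.0.weight"
    #      "down_blocks.0.resnets.0.norm1.weight"  -> "input_blocks.1.0.in_layers.0.weight"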
105
+ # ================#
106
+ # VAE Conversion #
107
+ # ================#
108
+
109
+ vae_conversion_map = [
110
+ # (stable-diffusion, HF Diffusers)
111
+ ("nin_shortcut", "conv_shortcut"),
112
+ ("norm_out", "conv_norm_out"),
113
+ ("mid.attn_1.", "mid_block.attentions.0."),
114
+ ]
115
+
116
+ for i in range(4):
117
+ # down_blocks have two resnets
118
+ for j in range(2):
119
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
120
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
121
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
122
+
123
+ if i < 3:
124
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
125
+ sd_downsample_prefix = f"down.{i}.downsample."
126
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
127
+
128
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
129
+ sd_upsample_prefix = f"up.{3 - i}.upsample."
130
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
131
+
132
+ # up_blocks have three resnets
133
+ # also, up blocks in hf are numbered in reverse from sd
134
+ for j in range(3):
135
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
136
+ sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
137
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
138
+
139
+ # this part accounts for mid blocks in both the encoder and the decoder
140
+ for i in range(2):
141
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
142
+ sd_mid_res_prefix = f"mid.block_{i + 1}."
143
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
144
+
145
+ vae_conversion_map_attn = [
146
+ # (stable-diffusion, HF Diffusers)
147
+ ("norm.", "group_norm."),
148
+ ("q.", "query."),
149
+ ("k.", "key."),
150
+ ("v.", "value."),
151
+ ("q.", "to_q."),
152
+ ("k.", "to_k."),
153
+ ("v.", "to_v."),
154
+ ("proj_out.", "to_out.0."),
155
+ ("proj_out.", "proj_attn."),
156
+ ]
157
+
158
+
159
+ def reshape_weight_for_sd(w):
160
+ # convert HF linear weights to SD conv2d weights
161
+ return w.reshape(*w.shape, 1, 1)
162
+
163
+
164
+ def convert_vae_state_dict(vae_state_dict):
165
+ mapping = {k: k for k in vae_state_dict.keys()}
166
+ for k, v in mapping.items():
167
+ for sd_part, hf_part in vae_conversion_map:
168
+ v = v.replace(hf_part, sd_part)
169
+ mapping[k] = v
170
+ for k, v in mapping.items():
171
+ if "attentions" in k:
172
+ for sd_part, hf_part in vae_conversion_map_attn:
173
+ v = v.replace(hf_part, sd_part)
174
+ mapping[k] = v
175
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
176
+ weights_to_convert = ["q", "k", "v", "proj_out"]
177
+ for k, v in new_state_dict.items():
178
+ for weight_name in weights_to_convert:
179
+ if f"mid.attn_1.{weight_name}.weight" in k:
180
+ print(f"Reshaping {k} for SD format")
181
+ new_state_dict[k] = reshape_weight_for_sd(v)
182
+ return new_state_dict
183
+
184
+
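For the mid-block attention weights, reshape_weight_for_sd turns the HF linear projections back into 1x1 conv kernels; a sketch with the usual 512-channel VAE mid-block, values assumed:

    w = torch.zeros((512, 512))          # hypothetical q / k / v / proj_out weight from a diffusers VAE
    reshape_weight_for_sd(w).shape       # torch.Size([512, 512, 1, 1]), matching SD's conv2d attention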
185
+ # =========================#
186
+ # Text Encoder Conversion #
187
+ # =========================#
188
+
189
+
190
+ textenc_conversion_lst = [
191
+ # (stable-diffusion, HF Diffusers)
192
+ ("resblocks.", "text_model.encoder.layers."),
193
+ ("ln_1", "layer_norm1"),
194
+ ("ln_2", "layer_norm2"),
195
+ (".c_fc.", ".fc1."),
196
+ (".c_proj.", ".fc2."),
197
+ (".attn", ".self_attn"),
198
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
199
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
200
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
201
+ ]
202
+ protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
203
+ textenc_pattern = re.compile("|".join(protected.keys()))
204
+
205
+ # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
206
+ code2idx = {"q": 0, "k": 1, "v": 2}
207
+
208
+
209
+ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
210
+ new_state_dict = {}
211
+ capture_qkv_weight = {}
212
+ capture_qkv_bias = {}
213
+ for k, v in text_enc_dict.items():
214
+ if not k.startswith(prefix):
215
+ continue
216
+ if (
217
+ k.endswith(".self_attn.q_proj.weight")
218
+ or k.endswith(".self_attn.k_proj.weight")
219
+ or k.endswith(".self_attn.v_proj.weight")
220
+ ):
221
+ k_pre = k[: -len(".q_proj.weight")]
222
+ k_code = k[-len("q_proj.weight")]
223
+ if k_pre not in capture_qkv_weight:
224
+ capture_qkv_weight[k_pre] = [None, None, None]
225
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
226
+ continue
227
+
228
+ if (
229
+ k.endswith(".self_attn.q_proj.bias")
230
+ or k.endswith(".self_attn.k_proj.bias")
231
+ or k.endswith(".self_attn.v_proj.bias")
232
+ ):
233
+ k_pre = k[: -len(".q_proj.bias")]
234
+ k_code = k[-len("q_proj.bias")]
235
+ if k_pre not in capture_qkv_bias:
236
+ capture_qkv_bias[k_pre] = [None, None, None]
237
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
238
+ continue
239
+
240
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
241
+ new_state_dict[relabelled_key] = v
242
+
243
+ for k_pre, tensors in capture_qkv_weight.items():
244
+ if None in tensors:
245
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
246
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
247
+ new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
248
+
249
+ for k_pre, tensors in capture_qkv_bias.items():
250
+ if None in tensors:
251
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
252
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
253
+ new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
254
+
255
+ return new_state_dict
256
+
257
+
258
+ def convert_text_enc_state_dict(text_enc_dict):
259
+ return text_enc_dict
260
+
261
+
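A minimal end-to-end sketch of combining these converters, assuming hypothetical diffusers-format state dicts for a UNet, VAE, and an SD 2.x style text encoder already loaded in memory:

    sd_unet = convert_unet_state_dict(diffusers_unet)
    sd_vae = convert_vae_state_dict(diffusers_vae)
    sd_text = convert_text_enc_state_dict_v20(diffusers_text)   # merges q/k/v projections into in_proj_weight / in_proj_bias
    # for SD 1.x CLIP text encoders the keys already match and convert_text_enc_state_dict is a no-op;
    # the upstream conversion script referenced at the top of this file then adds the usual checkpoint
    # prefixes (model.diffusion_model., first_stage_model., ...) before saving a single state dict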