Spaces:
Runtime error
Runtime error
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder | |
from ldm.modules.embedding_manager import EmbeddingManager | |
import argparse, os | |
from functools import partial | |
import torch | |
def get_placeholder_loop(placeholder_string, embedder, is_sd):
    """Interactively ask the user for a replacement placeholder string.

    Used when ``placeholder_string`` has already been taken. The user is
    re-prompted until the entered text maps to a single token, as decided by
    ``get_clip_token_for_string`` / ``get_bert_token_for_string``.

    Args:
        placeholder_string: The placeholder that was already used; it only
            appears in the first prompt.
        embedder: Text embedder exposing ``tokenizer`` (when ``is_sd``) or
            ``tknz_fn`` (otherwise).
        is_sd: True to validate with the CLIP tokenizer (Stable Diffusion),
            False to validate with the BERT tokenizer function.

    Returns:
        Tuple ``(new_placeholder, token)`` where ``token`` is the single token
        id the accepted string maps to.
    """

    def _lookup(text):
        # Both helpers return None when `text` is not exactly one token.
        if is_sd:
            return get_clip_token_for_string(embedder.tokenizer, text)
        return get_bert_token_for_string(embedder.tknz_fn, text)

    prompt = (
        f"Placeholder string {placeholder_string} was already used. "
        "Please enter a replacement string: "
    )
    while True:
        candidate = input(prompt)
        token = _lookup(candidate)
        if token is not None:
            return candidate, token
        # Rejected: the next prompt tells the user why the last answer failed.
        prompt = (
            f"Placeholder string '{candidate}' maps to more than a single token. "
            "Please enter another string: "
        )
def get_clip_token_for_string(tokenizer, string):
    """Return the CLIP token id for ``string`` if it is exactly one token.

    ``string`` is encoded and padded to 77 ids. Every id different from 49407
    is counted. A count of two (presumably the start-of-text id plus one
    content token) means ``string`` is a single token, and the id at position 1
    is returned.

    Args:
        tokenizer: Callable tokenizer returning a mapping whose
            ``"input_ids"`` entry is a 2-D tensor (``return_tensors="pt"``).
        string: Candidate placeholder text.

    Returns:
        ``input_ids[0, 1]`` (a 0-d tensor), or ``None`` when ``string`` does
        not map to exactly one token.
    """
    # NOTE(review): 49407 is presumably CLIP's end-of-text id, which this
    # tokenizer is assumed to also use for padding -- confirm for others.
    end_of_text = 49407
    encoded = tokenizer(
        string,
        truncation=True,
        max_length=77,
        return_length=True,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="pt",
    )
    ids = encoded["input_ids"]
    content_ids = torch.count_nonzero(ids - end_of_text)
    return ids[0, 1] if content_ids == 2 else None
def get_bert_token_for_string(tokenizer, string):
    """Return the BERT token id for ``string`` if it is exactly one token.

    Args:
        tokenizer: Callable mapping a string to a 2-D tensor of token ids
            (indexed as ``[0, 1]`` below).
        string: Candidate placeholder text.

    Returns:
        ``ids[0, 1]`` when exactly three ids are non-zero, otherwise ``None``.
    """
    ids = tokenizer(string)
    # Three non-zero ids presumably means start marker + one word-piece + end
    # marker, with 0 as the padding id -- confirm against the embedder's tknz_fn.
    if torch.count_nonzero(ids) != 3:
        return None
    return ids[0, 1]
def _parse_args(argv=None):
    """Parse command-line options; ``argv=None`` means ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--manager_ckpts",
        type=str,
        nargs="+",
        required=True,
        help="Paths to a set of embedding managers to be merged."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Output path for the merged manager",
    )
    parser.add_argument(
        "-sd", "--stable_diffusion",
        action="store_true",
        help="Flag to denote that we are merging stable diffusion embeddings"
    )
    return parser.parse_args(argv)


def _build_embedder(use_stable_diffusion):
    """Create the CLIP (SD) or BERT embedder used to validate placeholders."""
    # Fall back to CPU instead of crashing on GPU-less machines. In this script
    # the embedder is only used for its tokenizer and to construct
    # EmbeddingManager instances -- NOTE(review): confirm EmbeddingManager.load
    # maps checkpoints to CPU when no GPU is present.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if use_stable_diffusion:
        return FrozenCLIPEmbedder().to(device)
    return BERTEmbedder(n_embed=1280, n_layer=32).to(device)


def main(argv=None):
    """Merge several EmbeddingManager checkpoints into one and save it.

    The first checkpoint that uses a placeholder keeps it. Later duplicates are
    renamed interactively via ``get_placeholder_loop``. A replacement name that
    is itself already taken is rejected and the user is asked again.

    Args:
        argv: Command-line arguments; ``None`` uses ``sys.argv[1:]``.
    """
    args = _parse_args(argv)
    embedder = _build_embedder(args.stable_diffusion)

    # Bind the embedder and the default "*" placeholder list. Use a new name:
    # rebinding `EmbeddingManager` would shadow the imported class (and inside a
    # function would raise UnboundLocalError on the right-hand side).
    make_manager = partial(EmbeddingManager, embedder, ["*"])

    string_to_token_dict = {}
    string_to_param_dict = torch.nn.ParameterDict()
    placeholder_to_src = {}

    for manager_ckpt in args.manager_ckpts:
        print(f"Parsing {manager_ckpt}...")
        manager = make_manager()
        manager.load(manager_ckpt)

        for placeholder_string in manager.string_to_token_dict:
            token = manager.string_to_token_dict[placeholder_string]
            params = manager.string_to_param_dict[placeholder_string]

            new_placeholder = placeholder_string
            # Re-prompt until the name is free: a user-supplied replacement may
            # collide with an already-merged placeholder and would otherwise
            # overwrite its token and embedding.
            while new_placeholder in string_to_token_dict:
                new_placeholder, token = get_placeholder_loop(
                    new_placeholder, embedder, is_sd=args.stable_diffusion
                )

            string_to_token_dict[new_placeholder] = token
            string_to_param_dict[new_placeholder] = params
            placeholder_to_src[new_placeholder] = manager_ckpt

    print("Saving combined manager...")
    merged_manager = make_manager()
    merged_manager.string_to_param_dict = string_to_param_dict
    merged_manager.string_to_token_dict = string_to_token_dict
    merged_manager.save(args.output_path)

    print("Managers merged. Final list of placeholders: ")
    print(placeholder_to_src)


if __name__ == "__main__":
    main()