import os
from collections import namedtuple
from contextlib import closing

import torch
import tqdm
import html
import datetime
import csv
import safetensors.torch

import numpy as np
from PIL import Image, PngImagePlugin

from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay


TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
textual_inversion_templates = {}


def list_textual_inversion_templates():
    textual_inversion_templates.clear()

    for root, _, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
        for fn in fns:
            path = os.path.join(root, fn)

            textual_inversion_templates[fn] = TextualInversionTemplate(fn, path)

    return textual_inversion_templates


class Embedding:
    def __init__(self, vec, name, step=None):
        self.vec = vec
        self.name = name
        self.step = step
        self.shape = None
        self.vectors = 0
        self.cached_checksum = None
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.optimizer_state_dict = None
        self.filename = None
        self.hash = None
        self.shorthash = None

    def save(self, filename):
        embedding_data = {
            "string_to_token": {"*": 265},
            "string_to_param": {"*": self.vec},
            "name": self.name,
            "step": self.step,
            "sd_checkpoint": self.sd_checkpoint,
            "sd_checkpoint_name": self.sd_checkpoint_name,
        }

        torch.save(embedding_data, filename)

        if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
            optimizer_saved_dict = {
                'hash': self.checksum(),
                'optimizer_state_dict': self.optimizer_state_dict,
            }
            torch.save(optimizer_saved_dict, f"{filename}.optim")

    def checksum(self):
        if self.cached_checksum is not None:
            return self.cached_checksum

        def const_hash(a):
            r = 0
            for v in a:
                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
            return r

        self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
        return self.cached_checksum

    def set_hash(self, v):
        self.hash = v
        self.shorthash = self.hash[0:12]


class DirWithTextualInversionEmbeddings:
    def __init__(self, path):
        self.path = path
        self.mtime = None

    def has_changed(self):
        if not os.path.isdir(self.path):
            return False

        mt = os.path.getmtime(self.path)
        if self.mtime is None or mt > self.mtime:
            return True

    def update(self):
        if not os.path.isdir(self.path):
            return

        self.mtime = os.path.getmtime(self.path)
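

# EmbeddingDatabase (below) keeps two views of the loaded embeddings:
# word_embeddings maps each embedding name to its Embedding object, while
# ids_lookup maps the first token id of a name to (token_ids, embedding) pairs
# sorted longest-first, so find_embedding_at_position matches multi-token names
# before shorter ones. Note that in this copy the tokenizer call in
# register_embedding_by_name is stubbed with a placeholder ids list.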


class EmbeddingDatabase:
    def __init__(self):
        self.ids_lookup = {}
        self.word_embeddings = {}
        self.skipped_embeddings = {}
        self.expected_shape = -1
        self.embedding_dirs = {}
        self.previously_displayed_embeddings = ()

    def add_embedding_dir(self, path):
        self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)

    def clear_embedding_dirs(self):
        self.embedding_dirs.clear()

    def register_embedding(self, embedding, model):
        return self.register_embedding_by_name(embedding, model, embedding.name)

    def register_embedding_by_name(self, embedding, model, name):
        ids = [0, 0, 0]  # model.cond_stage_model.tokenize([name])[0]
        first_id = ids[0]
        if first_id not in self.ids_lookup:
            self.ids_lookup[first_id] = []
        if name in self.word_embeddings:
            # remove old one from the lookup list
            lookup = [x for x in self.ids_lookup[first_id] if x[1].name != name]
        else:
            lookup = self.ids_lookup[first_id]
        if embedding is not None:
            lookup += [(ids, embedding)]
        self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
        if embedding is None:
            # unregister embedding with specified name
            if name in self.word_embeddings:
                del self.word_embeddings[name]
            if len(self.ids_lookup[first_id]) == 0:
                del self.ids_lookup[first_id]
            return None
        self.word_embeddings[name] = embedding
        return embedding

    def get_expected_shape(self):
        devices.torch_npu_set_device()
        vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
        return vec.shape[1]

    def load_from_file(self, path, filename):
        name, ext = os.path.splitext(filename)
        ext = ext.upper()

        if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
            _, second_ext = os.path.splitext(name)
            if second_ext.upper() == '.PREVIEW':
                return

            embed_image = Image.open(path)
            if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
                data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
                name = data.get('name', name)
            else:
                data = extract_image_data_embed(embed_image)
                if data:
                    name = data.get('name', name)
                else:
                    # if data is None, means this is not an embedding, just a preview image
                    return
        elif ext in ['.BIN', '.PT']:
            data = torch.load(path, map_location="cpu")
        elif ext in ['.SAFETENSORS']:
            data = safetensors.torch.load_file(path, device="cpu")
        else:
            return

        if data is not None:
            embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
            self.register_embedding(embedding, None)
        else:
            print(f"Unable to load Textual inversion embedding due to data issue: '{name}'.")

    def load_from_dir(self, embdir):
        if not os.path.isdir(embdir.path):
            return

        for root, _, fns in os.walk(embdir.path, followlinks=True):
            for fn in fns:
                try:
                    fullfn = os.path.join(root, fn)

                    if os.stat(fullfn).st_size == 0:
                        continue

                    self.load_from_file(fullfn, fn)
                except Exception:
                    errors.report(f"Error loading embedding {fn}", exc_info=True)
                    continue

    def load_textual_inversion_embeddings(self, force_reload=False, sync_with_sd_model=True):
        if not force_reload:
            need_reload = False
            for embdir in self.embedding_dirs.values():
                if embdir.has_changed():
                    need_reload = True
                    break

            if not need_reload:
                return

        self.ids_lookup.clear()
        self.word_embeddings.clear()
        self.skipped_embeddings.clear()
        if sync_with_sd_model:
            self.expected_shape = self.get_expected_shape()

        for embdir in self.embedding_dirs.values():
            self.load_from_dir(embdir)
            embdir.update()

        # re-sort word_embeddings because load_from_dir may not load in alphabetic order.
        # using a temporary copy so we don't reinitialize self.word_embeddings in case other objects have a reference to it.
        sorted_word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())}
        self.word_embeddings.clear()
        self.word_embeddings.update(sorted_word_embeddings)

        displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
        if shared.opts.textual_inversion_print_at_load and self.previously_displayed_embeddings != displayed_embeddings:
            self.previously_displayed_embeddings = displayed_embeddings
            print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
            if self.skipped_embeddings:
                print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")

    def find_embedding_at_position(self, tokens, offset):
        token = tokens[offset]
        possible_matches = self.ids_lookup.get(token, None)

        if possible_matches is None:
            return None, None

        for ids, embedding in possible_matches:
            if tokens[offset:offset + len(ids)] == ids:
                return embedding, len(ids)

        return None, None
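

# Typical flow (illustrative sketch only; assumes a running webui session with
# shared.sd_model loaded, and `tokens` being a list of prompt token ids):
#
#     db = EmbeddingDatabase()
#     db.add_embedding_dir(shared.cmd_opts.embeddings_dir)
#     db.load_textual_inversion_embeddings(force_reload=True)
#     embedding, length = db.find_embedding_at_position(tokens, 0)
#
# find_embedding_at_position returns (None, None) when no registered embedding
# starts at the given token offset.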


def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
    cond_model = shared.sd_model.cond_stage_model

    with devices.autocast():
        cond_model([""])  # will send cond model to GPU if lowvram/medvram is active

    # cond_model expects at least some text, so we provide '*' as backup.
    embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
    vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)

    # Only copy if we provided an init_text, otherwise keep vectors as zeros
    if init_text:
        for i in range(num_vectors_per_token):
            vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]

    # Remove illegal characters from name.
    name = "".join(x for x in name if (x.isalnum() or x in "._- "))
    fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
    if not overwrite_old:
        assert not os.path.exists(fn), f"file {fn} already exists"

    embedding = Embedding(vec, name)
    embedding.step = 0
    embedding.save(fn)

    return fn
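

# Worked example for the init-vector sampling above (hypothetical numbers): if
# init_text encodes to 4 rows and num_vectors_per_token is 2, the copied row
# indices are i * 4 // 2 for i in (0, 1), i.e. rows 0 and 2, so the init rows
# are sampled evenly rather than truncated.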


def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
    if 'string_to_param' in data:  # textual inversion embeddings
        param_dict = data['string_to_param']
        param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
        assert len(param_dict) == 1, 'embedding file has multiple terms in it'
        emb = next(iter(param_dict.items()))[1]
        vec = emb.detach().to(devices.device, dtype=torch.float32)
        shape = vec.shape[-1]
        vectors = vec.shape[0]
    elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
        vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
        shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
        vectors = data['clip_g'].shape[0]
    elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
        assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

        emb = next(iter(data.values()))
        if len(emb.shape) == 1:
            emb = emb.unsqueeze(0)

        vec = emb.detach().to(devices.device, dtype=torch.float32)
        shape = vec.shape[-1]
        vectors = vec.shape[0]
    else:
        raise Exception(f"Couldn't identify {filename} as either a textual inversion embedding or a diffuser concept.")

    embedding = Embedding(vec, name)
    embedding.step = data.get('step', None)
    embedding.sd_checkpoint = data.get('sd_checkpoint', None)
    embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
    embedding.vectors = vectors
    embedding.shape = shape

    if filepath:
        embedding.filename = filepath
        embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '')

    return embedding
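

# create_embedding_from_data accepts three on-disk layouts, as handled above:
#   * original textual inversion files: a dict whose 'string_to_param' mapping has
#     exactly one entry (shape is the per-token dimension, vectors the token count)
#   * SDXL embeddings: a dict with 'clip_g' and 'clip_l' tensors; shape is the sum
#     of the two hidden sizes
#   * diffusers concepts: a dict with exactly one tensor value; a 1-D tensor is
#     treated as a single vector
# Anything else raises; load_from_dir already wraps each file in try/except and
# reports the failure per file.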