Spaces:
Runtime error
Runtime error
import functools | |
import importlib | |
import os | |
from functools import partial | |
from inspect import isfunction | |
import fsspec | |
import numpy as np | |
import torch | |
from PIL import Image, ImageDraw, ImageFont | |
from safetensors.torch import load_file as load_safetensors | |
import torch.distributed | |
_CONTEXT_PARALLEL_GROUP = None | |
_CONTEXT_PARALLEL_SIZE = None | |
def is_context_parallel_initialized():
    """Report whether the context-parallel group has been set.

    Returns:
        bool: ``True`` when ``_CONTEXT_PARALLEL_GROUP`` is not ``None``,
        ``False`` otherwise.
    """
    return _CONTEXT_PARALLEL_GROUP is not None
def initialize_context_parallel(context_parallel_size):
    """Split all ranks into consecutive context-parallel process groups.

    Ranks ``[0, size)``, ``[size, 2 * size)``, ... each form one group. The
    group containing the calling rank is stored in ``_CONTEXT_PARALLEL_GROUP``
    and the size in ``_CONTEXT_PARALLEL_SIZE``.

    Args:
        context_parallel_size: Number of ranks per group. Must be positive
            and divide the world size evenly.

    Raises:
        AssertionError: If a context-parallel group was already initialized.
        ValueError: If ``context_parallel_size`` is not positive or does not
            divide the world size.
    """
    global _CONTEXT_PARALLEL_GROUP
    global _CONTEXT_PARALLEL_SIZE

    assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    # Validate before touching module state so a bad size leaves it untouched.
    if context_parallel_size <= 0 or world_size % context_parallel_size != 0:
        raise ValueError(
            f"context_parallel_size ({context_parallel_size}) must be positive and "
            f"divide the world size ({world_size})"
        )

    _CONTEXT_PARALLEL_SIZE = context_parallel_size

    for start in range(0, world_size, context_parallel_size):
        ranks = list(range(start, start + context_parallel_size))
        # new_group must be entered by every process for every group, in the
        # same order, even by non-members. Therefore the loop must not stop
        # early once this rank has found its own group.
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _CONTEXT_PARALLEL_GROUP = group
def get_context_parallel_group():
    """Return the context-parallel process group stored for this rank.

    Raises:
        AssertionError: If the group has not been initialized yet.
    """
    group = _CONTEXT_PARALLEL_GROUP
    assert group is not None, "context parallel group is not initialized"
    return group
def get_context_parallel_world_size():
    """Return the number of ranks in one context-parallel group.

    Raises:
        AssertionError: If the context-parallel size has not been set yet.
    """
    size = _CONTEXT_PARALLEL_SIZE
    assert size is not None, "context parallel size is not initialized"
    return size
def get_context_parallel_rank():
    """Return this process's position inside its context-parallel group.

    Computed as the global rank modulo the context-parallel size.

    Raises:
        AssertionError: If the context-parallel size has not been set yet.
    """
    assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
    return torch.distributed.get_rank() % _CONTEXT_PARALLEL_SIZE
def get_context_parallel_group_rank():
    """Return the index of the context-parallel group this process is in.

    Computed as the global rank floor-divided by the context-parallel size.

    Raises:
        AssertionError: If the context-parallel size has not been set yet.
    """
    assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
    return torch.distributed.get_rank() // _CONTEXT_PARALLEL_SIZE
class SafeConv3d(torch.nn.Conv3d):
    """``Conv3d`` that convolves very large inputs in slices along dim 2.

    When the estimated input size exceeds 2 GiB, the input is split along
    dim 2, each slice is convolved separately and the results are
    concatenated along dim 2 again. Smaller inputs take the regular
    ``Conv3d`` path.
    """

    def forward(self, input):
        """Apply the convolution, chunking along dim 2 for large inputs.

        Args:
            input: Tensor laid out as NCTHW (per the chunking comment below).

        Returns:
            The convolution output, assembled from per-chunk outputs when the
            input was chunked.
        """
        # Size estimate in GiB using 2 bytes per element.
        # NOTE(review): presumably assumes half-precision inputs - confirm.
        size_gib = input.numel() * 2 / 1024**3
        if size_gib <= 2:
            return super().forward(input)

        # Bound before the comprehension: zero-argument super() is not
        # usable inside a comprehension scope on older Python versions.
        plain_forward = super().forward
        temporal_kernel = self.kernel_size[0]
        num_parts = int(size_gib / 2) + 1
        pieces = list(torch.chunk(input, num_parts, dim=2))  # NCTHW

        if temporal_kernel > 1:
            # Prefix every chunk but the first with the trailing
            # (kernel - 1) frames of its predecessor so frames at chunk
            # boundaries still see their left context.
            # NOTE(review): this stitching looks correct only for temporal
            # stride 1, dilation 1 and no temporal padding - confirm.
            overlap = temporal_kernel - 1
            pieces = pieces[:1] + [
                torch.cat((previous[:, :, -overlap:], current), dim=2)
                for previous, current in zip(pieces, pieces[1:])
            ]

        return torch.cat([plain_forward(piece) for piece in pieces], dim=2)
def disabled_train(self, mode=True):
    """No-op replacement for ``model.train``.

    Assign this over a model's ``train`` method so that later calls to
    switch between train and eval mode have no effect.

    Args:
        mode: Accepted for signature compatibility and ignored.

    Returns:
        ``self``, unchanged.
    """
    return self
def get_string_from_tuple(s):
    """Return the first element of a tuple written as a string literal.

    If ``s`` is a string of the form ``"(...)"`` that parses as a non-empty
    Python tuple literal, the first element of that tuple is returned. In
    every other case ``s`` is returned unchanged.

    The string is parsed with ``ast.literal_eval`` rather than ``eval``, so
    only literals are accepted and text such as ``"(__import__('os'),)"`` is
    never executed. It is returned as-is instead.

    Args:
        s: Value to inspect. Non-string values are returned unchanged.

    Returns:
        The first tuple element, or ``s`` itself.
    """
    import ast  # local import keeps this block self-contained

    if not isinstance(s, str) or not (s.startswith("(") and s.endswith(")")):
        return s

    try:
        parsed = ast.literal_eval(s)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Not a literal: treat the input as an ordinary string.
        return s

    # "(1)" parses to an int and "()" to an empty tuple; neither has a first
    # element to extract.
    if isinstance(parsed, tuple) and parsed:
        return parsed[0]
    return s
def is_power_of_two(n):
    """Return True if ``n`` is a positive power of two, otherwise False.

    A power of two has exactly one bit set. Subtracting one clears that bit
    and sets every lower bit, so ``n & (n - 1)`` is zero only for powers of
    two. Values less than or equal to zero are never powers of two.
    """
    return False if n <= 0 else (n & (n - 1)) == 0
def autocast(f, enabled=True):
    """Wrap ``f`` so that each call runs inside a CUDA autocast context.

    The autocast dtype and cache setting are read from the current global
    autocast state at call time, not when the wrapper is created.

    Args:
        f: Callable to wrap.
        enabled: Passed to ``torch.cuda.amp.autocast`` as ``enabled``.

    Returns:
        A callable with the same arguments and return value as ``f``. The
        wrapped function's name and docstring are preserved.
    """

    @functools.wraps(f)
    def do_autocast(*args, **kwargs):
        with torch.cuda.amp.autocast(
            enabled=enabled,
            dtype=torch.get_autocast_gpu_dtype(),
            cache_enabled=torch.is_autocast_cache_enabled(),
        ):
            return f(*args, **kwargs)

    return do_autocast
def load_partial_from_config(config):
    """Build a ``functools.partial`` for the object named in ``config``.

    Args:
        config: Mapping with a ``"target"`` dotted path and an optional
            ``"params"`` mapping of keyword arguments.

    Returns:
        ``partial(target, **params)``. The target is resolved but not called.
    """
    target = get_obj_from_str(config["target"])
    bound_kwargs = config.get("params", dict())
    return partial(target, **bound_kwargs)
def log_txt_as_img(wh, xc, size=10):
    """Render each caption as black text on a white image.

    Args:
        wh: Tuple ``(width, height)`` of each rendered image.
        xc: List of captions. When an entry is itself a list, only its first
            element is drawn.
        size: Font size passed to ``ImageFont.truetype``.

    Returns:
        A tensor stacking one channel-first array per caption, with pixel
        values rescaled by ``/ 127.5 - 1.0``.
    """
    # The font and the wrap width do not depend on the caption, so they are
    # computed once instead of re-reading the font file for every caption.
    font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
    chars_per_line = int(40 * (wh[0] / 256))

    rendered = []
    for caption in xc:
        canvas = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(canvas)

        text_seq = caption[0] if isinstance(caption, list) else caption
        # Hard-wrap the caption every chars_per_line characters.
        lines = "\n".join(
            text_seq[start : start + chars_per_line]
            for start in range(0, len(text_seq), chars_per_line)
        )

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            # Best effort: keep the blank canvas for this caption.
            print("Cant encode string for logging. Skipping.")

        # HWC image -> CHW array, then rescale.
        rendered.append(np.array(canvas).transpose(2, 0, 1) / 127.5 - 1.0)

    return torch.tensor(np.stack(rendered))
def partialclass(cls, *args, **kwargs):
    """Create a subclass of ``cls`` with constructor arguments pre-bound.

    Args:
        cls: Class to derive from.
        *args: Positional arguments bound to ``cls.__init__``.
        **kwargs: Keyword arguments bound to ``cls.__init__``.

    Returns:
        A subclass named ``NewCls`` whose ``__init__`` is a
        ``functools.partialmethod`` of ``cls.__init__``.
    """
    prebound_init = functools.partialmethod(cls.__init__, *args, **kwargs)

    class NewCls(cls):
        __init__ = prebound_init

    return NewCls
def make_path_absolute(path):
    """Return an absolute path for local files; leave other URLs unchanged.

    Args:
        path: A path or URL understood by ``fsspec.core.url_to_fs``.

    Returns:
        ``os.path.abspath`` of the stripped path when the resolved filesystem
        uses the ``"file"`` protocol, otherwise ``path`` unchanged.
    """
    fs, p = fsspec.core.url_to_fs(path)
    # fs.protocol may be a single string or a tuple of aliases (for example
    # ("file", "local")) depending on the fsspec version, so a plain
    # equality test against "file" is not sufficient.
    protocol = fs.protocol
    protocols = (protocol,) if isinstance(protocol, str) else tuple(protocol)
    if "file" in protocols:
        return os.path.abspath(p)
    return path
def ismap(x):
    """Return True if ``x`` is a 4-D tensor whose dim 1 is larger than 3."""
    return isinstance(x, torch.Tensor) and len(x.shape) == 4 and x.shape[1] > 3
def isimage(x):
    """Return True if ``x`` is a 4-D tensor whose dim 1 has size 3 or 1."""
    return isinstance(x, torch.Tensor) and len(x.shape) == 4 and x.shape[1] in (3, 1)
def isheatmap(x):
    """Return True if ``x`` is a tensor with exactly two dimensions."""
    return isinstance(x, torch.Tensor) and x.ndim == 2
def isneighbors(x):
    """Return True if ``x`` is a 5-D tensor whose dim 2 has size 3 or 1."""
    return isinstance(x, torch.Tensor) and x.ndim == 5 and x.shape[2] in (3, 1)
def exists(x):
    """Return True unless ``x`` is ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    return x is not None
def expand_dims_like(x, y):
    """Append trailing size-1 dimensions to ``x`` until it has ``y``'s rank.

    Args:
        x: Tensor to expand.
        y: Tensor whose number of dimensions is the target.

    Returns:
        ``x`` with ``y.dim() - x.dim()`` singleton dimensions appended. When
        the ranks already match, the result has the same shape as ``x``.

    Raises:
        ValueError: If ``x`` has more dimensions than ``y``. Appending
            dimensions can never make the ranks match in that case, and
            looping until they are equal would never terminate.
    """
    missing = y.dim() - x.dim()
    if missing < 0:
        raise ValueError(f"x has {x.dim()} dims but y has only {y.dim()}; cannot expand x to match")
    # Indexing with None inserts singleton axes, equivalent to repeated
    # unsqueeze(-1).
    return x[(...,) + (None,) * missing]
def default(val, d):
    """Return ``val`` when it is set, otherwise a fallback.

    Args:
        val: Preferred value, used whenever ``exists(val)`` is true.
        d: Fallback. If it is a plain Python function (per
            ``inspect.isfunction``) it is called with no arguments and its
            result is returned; otherwise ``d`` itself is returned.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val
def mean_flat(tensor):
    """Average ``tensor`` over every dimension except the first.

    Adapted from
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    """
    non_batch_dims = list(range(1, tensor.ndim))
    return tensor.mean(dim=non_batch_dims)
def count_params(model, verbose=False):
    """Count the elements of all parameters of ``model``.

    Args:
        model: Object exposing ``parameters()``.
        verbose: When true, also print the count in millions together with
            the model's class name.

    Returns:
        int: Total number of parameter elements.
    """
    total_params = 0
    for parameter in model.parameters():
        total_params += parameter.numel()
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
    return total_params
def instantiate_from_config(config):
    """Instantiate the object described by ``config``.

    Args:
        config: Mapping with a ``"target"`` dotted path and an optional
            ``"params"`` mapping of constructor keyword arguments. The
            sentinel strings ``"__is_first_stage__"`` and
            ``"__is_unconditional__"`` are also accepted.

    Returns:
        ``target(**params)``, or ``None`` for either sentinel string.

    Raises:
        KeyError: If ``config`` has no ``"target"`` and is not a sentinel.
    """
    if "target" in config:
        constructor = get_obj_from_str(config["target"])
        return constructor(**config.get("params", dict()))
    if config == "__is_first_stage__" or config == "__is_unconditional__":
        return None
    raise KeyError("Expected key `target` to instantiate.")
def get_obj_from_str(string, reload=False, invalidate_cache=True):
    """Resolve a dotted path such as ``"package.module.Name"`` to an object.

    Args:
        string: Dotted path. Everything before the last dot is imported as a
            module and the remainder is looked up as an attribute on it.
        reload: When true, reload the module before the lookup.
        invalidate_cache: When true, call ``importlib.invalidate_caches()``
            before importing.

    Returns:
        The attribute named by the last path component.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if invalidate_cache:
        importlib.invalidate_caches()
    if reload:
        importlib.reload(importlib.import_module(module_name))
    target_module = importlib.import_module(module_name, package=None)
    return getattr(target_module, attr_name)
def append_zero(x):
    """Return ``x`` with a single zero element concatenated at the end.

    The zero is created with ``x.new_zeros`` so it matches ``x``'s dtype and
    device.
    """
    trailing_zero = x.new_zeros([1])
    return torch.cat([x, trailing_zero])
def append_dims(x, target_dims):
    """Append singleton dimensions to ``x`` until it has ``target_dims`` dims.

    Raises:
        ValueError: If ``x`` already has more than ``target_dims`` dims.
    """
    missing = target_dims - x.ndim
    if missing >= 0:
        # Indexing with None inserts one size-1 axis per None.
        return x[(...,) + (None,) * missing]
    raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
def load_model_from_config(config, ckpt, verbose=True, freeze=True):
    """Build the model in ``config.model`` and load weights from ``ckpt``.

    Args:
        config: Object whose ``model`` attribute is passed to
            ``instantiate_from_config``.
        ckpt: Checkpoint path ending in ``"ckpt"`` (loaded with
            ``torch.load``, weights read from its ``"state_dict"`` entry) or
            ``"safetensors"``.
        verbose: When true, print missing and unexpected state-dict keys.
        freeze: When true, set ``requires_grad = False`` on every parameter.

    Returns:
        The model, switched to eval mode.

    Raises:
        NotImplementedError: If ``ckpt`` has neither supported suffix.
    """
    print(f"Loading model from {ckpt}")
    if ckpt.endswith("ckpt"):
        # NOTE(review): torch.load unpickles the file; only use checkpoints
        # from trusted sources.
        checkpoint_data = torch.load(ckpt, map_location="cpu")
        if "global_step" in checkpoint_data:
            print(f"Global Step: {checkpoint_data['global_step']}")
        state_dict = checkpoint_data["state_dict"]
    elif ckpt.endswith("safetensors"):
        state_dict = load_safetensors(ckpt)
    else:
        raise NotImplementedError

    model = instantiate_from_config(config.model)

    # strict=False: mismatching keys are reported below instead of raising.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if verbose:
        if len(missing) > 0:
            print("missing keys:")
            print(missing)
        if len(unexpected) > 0:
            print("unexpected keys:")
            print(unexpected)

    if freeze:
        for parameter in model.parameters():
            parameter.requires_grad = False

    model.eval()
    return model
def get_configs_path() -> str:
    """Locate the ``configs`` directory.

    Two locations are tried in order: ``configs`` next to this file, then
    ``configs`` one directory above it. The first existing directory is
    returned as an absolute path.

    Raises:
        FileNotFoundError: If neither candidate is an existing directory.
    """
    this_dir = os.path.dirname(__file__)
    candidates = (
        os.path.join(this_dir, "configs"),
        os.path.join(this_dir, "..", "configs"),
    )
    existing = next(
        (resolved for resolved in map(os.path.abspath, candidates) if os.path.isdir(resolved)),
        None,
    )
    if existing is None:
        raise FileNotFoundError(f"Could not find SGM configs in {candidates}")
    return existing
def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
    """Follow a dotted attribute path starting from ``obj``.

    E.g.:
        a.b.c
        = getattr(getattr(a, "b"), "c")
        = get_nested_attribute(a, "b.c")

    A path component that parses as an integer ``x`` is first tried as an
    index (``a[x]``). If indexing fails, the component is looked up as an
    attribute (``getattr(a, "x")``) instead.

    Args:
        obj: Root object.
        attribute_path: Dot-separated path.
        depth: When a positive integer, only the first ``depth`` components
            are followed. ``None`` or non-positive values follow the whole
            path.
        return_key: When true, also return the dotted path actually followed.

    Returns:
        The resolved object, or ``(resolved, key)`` when ``return_key`` is
        true.

    Raises:
        AttributeError: If a non-integer component is not an attribute.
        TypeError, KeyError, IndexError: If an integer component can be
            neither indexed nor found as an attribute. The original indexing
            error is re-raised.
    """
    attributes = attribute_path.split(".")
    if depth is not None and depth > 0:
        attributes = attributes[:depth]
    assert len(attributes) > 0, "At least one attribute should be selected"

    current_attribute = obj
    for attribute in attributes:
        try:
            id_ = int(attribute)
            current_attribute = current_attribute[id_]
        except ValueError:
            # Not an integer component (or the container itself raised
            # ValueError): plain attribute access.
            current_attribute = getattr(current_attribute, attribute)
        except (TypeError, KeyError, IndexError) as index_error:
            # Indexing failed, so fall back to an attribute literally named
            # like the integer, e.g. a submodule registered as "0".
            try:
                current_attribute = getattr(current_attribute, attribute)
            except AttributeError:
                # Keep reporting the indexing failure, as callers saw before.
                raise index_error from None

    # The key of the last visited level is the whole (possibly truncated) path.
    current_key = ".".join(attributes)
    return (current_attribute, current_key) if return_key else current_attribute
def checkpoint(func, inputs, params, flag):
    """Run ``func(*inputs)``, optionally through gradient checkpointing.

    With checkpointing, intermediate activations are not cached; memory use
    drops at the cost of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing and call `func`
                 directly.
    """
    if not flag:
        return func(*inputs)
    # Inputs and parameters travel in one flat argument list; the input count
    # tells CheckpointFunction where to split them again.
    packed = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *packed)
class CheckpointFunction(torch.autograd.Function):
    """Autograd function that recomputes the forward pass during backward.

    ``forward`` runs the wrapped function under ``torch.no_grad`` so no
    intermediate activations are kept. ``backward`` re-runs it with gradients
    enabled and differentiates the recomputed graph.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        """Run ``run_function`` on the first ``length`` entries of ``args``.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            run_function: Callable to evaluate.
            length: Number of leading entries of ``args`` that are passed to
                ``run_function``; the remaining entries are stored as
                parameters to differentiate with respect to.
            *args: Input tensors followed by parameters.

        Returns:
            Whatever ``run_function`` returns.
        """
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # Capture the autocast state now so the recomputation in backward
        # runs under the same settings.
        ctx.gpu_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(),
            "dtype": torch.get_autocast_gpu_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        """Recompute the forward pass and return gradients for all ``args``.

        Returns:
            ``(None, None)`` for ``run_function`` and ``length``, followed by
            one gradient per input tensor and per parameter. A gradient may
            be ``None`` for entries the output does not depend on.
        """
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        # Drop references so the recomputed graph can be freed.
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads