def extend_instance(obj, mixin):
    """Apply a mixin to an already-created instance.

    The instance's class is swapped for a freshly built subclass that has the
    same ``__name__`` as the original class and ``(mixin, original_class)`` as
    its bases. ``obj`` is modified in place; nothing is returned.
    """
    base_cls = obj.__class__
    # mixin needs to go first for our forward() logic to work
    obj.__class__ = type(base_cls.__name__, (mixin, base_cls), {})


def getattr_recursive(obj, att):
    """Return a nested attribute of ``obj``.

    Example: ``getattr_recursive(obj, 'a.b.c')`` is equivalent to
    ``obj.a.b.c``. An empty ``att`` returns ``obj`` itself.
    """
    # Consume one dotted component per iteration until nothing is left.
    while att:
        head, _, att = att.partition(".")
        obj = getattr(obj, head)
    return obj


def setattr_recursive(obj, att, val):
    """Set a nested attribute of ``obj``.

    Example: ``setattr_recursive(obj, 'a.b.c', val)`` is equivalent to
    ``obj.a.b.c = val``.
    """
    # Everything before the last dot names the owner of the final attribute;
    # an empty parent path resolves to ``obj`` itself.
    parent_path, _, leaf = att.rpartition(".")
    setattr(getattr_recursive(obj, parent_path), leaf, val)


# Maps a lower-case fragment of a model class name to the dotted attribute
# path under which that model keeps its transformer block layers.
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "llama": "model.layers",
    "mistral": "model.layers",
}


def _infer_decoder_layers_attr_name(model):
    """Look up the decoder-layers attribute path from the model's class name.

    The match is a case-insensitive substring test of each known key against
    ``model.__class__.__name__``; the first match wins.

    Raises:
        ValueError: if no known key occurs in the class name.
    """
    cls_name = model.__class__.__name__.lower()
    for key, attr_name in __KNOWN_DECODER_LAYERS_ATTR_NAMES.items():
        if key.lower() in cls_name:
            return attr_name

    raise ValueError(
        "We require the attribute name for the nn.ModuleList in the decoder storing"
        " the transformer block layers. Please supply this string manually."
    )


# Interpolation modes for which ``torch.nn.functional.interpolate`` accepts an
# explicit ``align_corners`` argument; other modes raise if it is set.
_ALIGN_CORNERS_MODES = frozenset({"linear", "bilinear", "bicubic", "trilinear"})


def _interpolate_keep_dtype(tensor, size, mode):
    """Resize the trailing spatial dims of ``tensor`` and keep its dtype.

    The interpolation itself runs in float32 and the result is cast back to
    the dtype ``tensor`` arrived with.
    """
    align_corners = False if mode in _ALIGN_CORNERS_MODES else None
    resized = torch.nn.functional.interpolate(
        tensor.float(),
        size=size,
        mode=mode,
        align_corners=align_corners,
    )
    return resized.to(tensor.dtype)


def resize_eva_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
    """Resize checkpoint position embeddings to the grid size ``model`` expects.

    Mutates ``state_dict`` in place and returns ``None``. Nothing happens when
    ``state_dict`` has no ``"pos_embed"`` entry or when the checkpoint grid
    already matches the model's grid.

    Args:
        state_dict: checkpoint mapping; reads and rewrites ``"pos_embed"`` and
            ``"patch_embed.proj.weight"``.
        model: object exposing ``patch_embed.num_patches``,
            ``patch_embed.patch_size`` and ``pos_embed``.
        interpolation: mode forwarded to ``torch.nn.functional.interpolate``
            for both resizes (previously ignored in favour of a hard-coded
            ``"bicubic"``).
        seq_dim: unused; kept so existing callers keep working.

    Raises:
        KeyError: if a resize is needed and ``"patch_embed.proj.weight"`` is
            missing from ``state_dict``.
    """
    if "pos_embed" not in state_dict:
        return

    pos_embed_checkpoint = state_dict["pos_embed"]
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.patch_embed.num_patches
    # Tokens in the model's pos_embed that are not patch positions.
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches

    # height (== width) for the checkpoint position embedding
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    # height (== width) for the new position embedding
    new_size = int(num_patches**0.5)
    if orig_size == new_size:
        return

    print(
        "Position interpolate from %dx%d to %dx%d"
        % (orig_size, orig_size, new_size, new_size)
    )

    # The leading extra tokens are kept unchanged; only the position tokens
    # are interpolated.
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]

    # (batch, H*W, C) -> (batch, C, H, W), the layout interpolate() expects.
    pos_grid = pos_tokens.reshape(
        -1, orig_size, orig_size, embedding_size
    ).permute(0, 3, 1, 2)
    pos_grid = _interpolate_keep_dtype(
        pos_grid, (new_size, new_size), interpolation
    )
    # Back to (batch, H*W, C).
    pos_tokens = pos_grid.permute(0, 2, 3, 1).flatten(1, 2)
    state_dict["pos_embed"] = torch.cat((extra_tokens, pos_tokens), dim=1)

    # NOTE(review): the source formatting was lost, so the nesting here is
    # inferred -- the patch projection is resized only when the position grid
    # changed. Confirm against the intended behaviour.
    state_dict["patch_embed.proj.weight"] = _interpolate_keep_dtype(
        state_dict["patch_embed.proj.weight"],
        model.patch_embed.patch_size,
        interpolation,
    )