# converttoP.py - uploaded by ND911 (commit 89eb606, verified; 660 bytes).
import numpy as np
import torch
import safetensors
from safetensors.torch import save_file
import matplotlib.pyplot as plt
# Perturb the weights of a safetensors checkpoint with small Gaussian noise.
#
# Every tensor whose key contains all substrings in KEY_PARTS receives
# zero-mean noise whose standard deviation is NOISE_SCALE times that tensor's
# own standard deviation. All other tensors are written out unchanged.
INPUT_PATH = 'sd3_medium_incl_clips_t5xxlfp16.safetensors'
OUTPUT_PATH = 'sd3_medium_incl_clips_t5xxlfp16.safetensors_perturbed3.safetensors'
KEY_PARTS = ('diffusion_model',)
NOISE_SCALE = .02


def add_relative_noise(tensor: torch.Tensor, scale: float = NOISE_SCALE) -> 'torch.Tensor | None':
    """Add zero-mean Gaussian noise to *tensor* in place.

    The noise standard deviation is ``scale * tensor.std()``; the noise is
    drawn with the tensor's own dtype and shape.

    Args:
        tensor: Tensor to modify in place.
        scale: Noise standard deviation as a fraction of the tensor's
            standard deviation.

    Returns:
        The tensor's standard deviation measured *before* the noise was
        added, or ``None`` when the tensor was left untouched because no
        usable standard deviation exists (non-floating dtype, fewer than two
        elements, or a zero / non-finite standard deviation).
    """
    # std() raises on integer dtypes and is NaN for a single element; either
    # would crash the run or write NaNs into the checkpoint.
    if not tensor.is_floating_point() or tensor.numel() < 2:
        return None
    spread = tensor.std()
    if not torch.isfinite(spread) or spread <= 0:
        return None
    tensor.add_(torch.randn_like(tensor) * (spread * scale))
    return spread


def main(src: str = INPUT_PATH, dst: str = OUTPUT_PATH,
         key_parts: tuple = KEY_PARTS, scale: float = NOISE_SCALE) -> None:
    """Load *src*, perturb the tensors matching *key_parts*, and save to *dst*.

    Prints ``key: std`` for every perturbed tensor, then the number of
    tensors that were perturbed. The source file's metadata is copied to the
    output file.
    """
    # The context manager closes the file handle as soon as the tensors and
    # metadata have been read.
    with safetensors.safe_open(src, 'pt') as model:
        metadata = model.metadata()
        dic = {key: model.get_tensor(key) for key in model.keys()}

    count = 0
    for k, v in dic.items():
        if not all(part in k for part in key_parts):
            continue
        spread = add_relative_noise(v, scale)
        if spread is None:
            print(f'{k}: skipped (no usable std)')
            continue
        print(f'{k}: {spread}')
        count += 1
    print(count)
    save_file(dic, dst, metadata)


if __name__ == '__main__':
    main()