File size: 4,904 Bytes
6eca12e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import torch
import numpy as np
import os, sys
import time
class LyraChecker:
    """Compare pairs of ``.npy`` dumps stored under one directory.

    Every result is reported on stdout; no method returns a verdict.
    """

    def __init__(self, dir_data, tol):
        """Store the dump directory and the default absolute tolerance."""
        self.dir_data = dir_data
        self.tol = tol

    def cmp(self, fpath1, fpath2="", tol=0):
        """Load two dumps and print whether they agree.

        Args:
            fpath1: File name of the first dump, relative to ``dir_data``.
                The ``.npy`` extension is appended when missing.
            fpath2: File name of the second dump. When empty, ``fpath1`` is
                used as a shared stem and ``<stem>_1`` is compared with
                ``<stem>_2``.
            tol: Absolute tolerance for this call only; ``0`` keeps the
                instance default.
        """
        saved_tol = self.tol
        if tol != 0:
            self.tol = tol
        try:
            if fpath2 == "":
                fpath1, fpath2 = fpath1 + "_1", fpath1 + "_2"
            v1 = self.get_npy(fpath1)
            v2 = self.get_npy(fpath2)
            name = fpath1
            if ".npy" in fpath1:
                # Report under the bare file stem rather than the full name.
                name = ".".join(os.path.basename(fpath1).split(".")[:-1])
            self._cmp_inner(v1, v2, name)
        finally:
            # Restore the default even when loading or comparing raises, so
            # a one-off tolerance never leaks into later calls.
            self.tol = saved_tol

    def _cmp_inner(self, v1, v2, name):
        """Bring ``v2`` to the layout of ``v1`` if needed, then compare."""
        print(v1.shape, v2.shape)
        if v1.shape != v2.shape:
            if v1.shape[1] == v2.shape[1]:
                # Same second axis: merge all trailing axes of v2 into one.
                v2 = v2.reshape([v2.shape[0], v2.shape[1], -1])
            else:
                # Otherwise move the last axis of v2 to position 1.
                v2 = v2.transpose(0, 3, 1, 2)
        print(v1.shape, v2.shape)
        self._check_data(name, v1, v2)
        print(np.size(v1))

    def _check_data(self, stage, x_out, x_gt):
        """Print an OK/FAIL report for ``x_out`` against ``x_gt``.

        ``self.tol`` is passed as ``atol`` to ``np.allclose``, which also
        applies its default relative tolerance. On failure the largest
        absolute difference and the count/ratio of elements whose absolute
        difference exceeds ``self.tol`` are printed.
        """
        print(f"========== {stage} =============")
        print(x_out.shape, x_gt.shape)
        if np.allclose(x_gt, x_out, atol=self.tol):
            print(f"[OK] At {stage}, tol: {self.tol}")
        else:
            abs_diff = np.abs(x_gt - x_out)
            diff_cnt = np.count_nonzero(abs_diff > self.tol)
            print(f"[FAIL]At {stage}, not aligned. tol: {self.tol}")
            print(" [INFO]Max diff: ", np.max(abs_diff))
            print(" [INFO]Diff count: ", diff_cnt, ", ratio: ", round(diff_cnt/np.size(x_out), 2))
        print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")

    def cmp_query(self, fpath1, fpath2):
        """Check separate query/key/value dumps against one fused dump.

        Args:
            fpath1: File name of the query dump, relative to ``dir_data``.
                The key and value file names are derived from it by
                replacing ``"query"`` with ``"key"`` / ``"value"``.
            fpath2: File name of the fused dump. It is indexed as
                ``[:, :, i, :, :]`` with ``i`` = 0, 1, 2 for query, key and
                value, and axes 1 and 2 of each slice are swapped before
                the comparison.
        """
        fused = np.load(os.path.join(self.dir_data, fpath2))
        for idx, stage in enumerate(("query", "key", "value")):
            # Substitute inside the file name only, so a data directory
            # whose path contains "query" is left untouched.
            ref_path = os.path.join(self.dir_data, fpath1.replace("query", stage))
            ref = np.load(ref_path)
            part = fused[:, :, idx, :, :].transpose([0, 2, 1, 3])
            self._check_data(stage, ref, part)

    def _get_data_fpath(self, fname):
        """Return ``fname`` joined onto ``dir_data`` with a ``.npy`` suffix."""
        fpath = os.path.join(self.dir_data, fname)
        if not fpath.endswith(".npy"):
            fpath += ".npy"
        return fpath

    def get_npy(self, fname):
        """Load the array stored as ``fname`` (``.npy`` optional)."""
        return np.load(self._get_data_fpath(fname))
class MkDataHelper:
    """Create and store ``.npy`` test data below a root directory."""

    def __init__(self, data_dir="/data/home/kiokaxiao/data"):
        """Remember the root directory that all sub-directories hang off."""
        self.data_dir = data_dir

    def mkdata(self, subdir, name, shape, dtype=torch.float16):
        """Sample random data, save it as ``<subdir>/<name>.npy``, return it.

        The values are always drawn as float16. Only the copy written to
        disk is converted to ``dtype``; the returned tensor is the float16
        sample itself.
        """
        outdir = os.path.join(self.data_dir, subdir)
        os.makedirs(outdir, exist_ok=True)
        fpath = os.path.join(outdir, name + ".npy")
        data = torch.randn(shape, dtype=torch.float16)
        np.save(fpath, data.to(dtype).numpy())
        return data

    def gen_out_with_func(self, func, inputs):
        """Return ``func(inputs)``."""
        return func(inputs)

    def savedata(self, subdir, name, data):
        """Save tensor ``data`` as ``<subdir>/<name>.npy``.

        The tensor is detached first so that tensors which require grad can
        be saved too (``Tensor.numpy()`` refuses them otherwise).
        """
        outdir = os.path.join(self.data_dir, subdir)
        os.makedirs(outdir, exist_ok=True)
        fpath = os.path.join(outdir, name + ".npy")
        np.save(fpath, data.detach().cpu().numpy())
class TorchSaver:
    """Dump torch tensors as ``.npy`` files into one directory.

    Setting ``is_save`` to ``False`` turns every save call into a no-op.
    """

    def __init__(self, data_dir):
        """Create ``data_dir`` if needed and enable saving."""
        self.data_dir = data_dir
        os.makedirs(self.data_dir, exist_ok=True)
        self.is_save = True

    def _dump(self, name, suffix, v):
        """Write ``v`` to ``<data_dir>/<name><suffix>.npy`` when saving is on."""
        if not self.is_save:
            return
        fpath = os.path.join(self.data_dir, name + suffix + ".npy")
        np.save(fpath, v.detach().cpu().numpy())

    def save_v(self, name, v):
        """Save tensor ``v`` as ``<name>_1.npy``."""
        self._dump(name, "_1", v)

    def save_v2(self, name, v):
        """Save tensor ``v`` as ``<name>_2.npy``.

        This used to write ``<name>_1.npy`` as well, overwriting the output
        of ``save_v``. The ``_2`` suffix is the one ``LyraChecker.cmp``
        pairs with ``_1`` when it is given a single stem.
        """
        self._dump(name, "_2", v)
def timer_annoc(funct):
    """Decorator that prints the wall-clock duration of each call.

    After ``funct`` returns, ``torch.cuda.synchronize()`` is called before
    the clock is stopped. The synchronisation is skipped when CUDA is not
    available, so decorated functions also run on CPU-only machines.

    Args:
        funct: The callable to time.

    Returns:
        A wrapper with the metadata of ``funct`` that forwards all
        arguments and returns the result of ``funct`` unchanged.
    """
    import functools

    @functools.wraps(funct)
    def inner(*args, **kwargs):
        start = time.perf_counter()
        res = funct(*args, **kwargs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        end = time.perf_counter()
        print("torch cost: ", end - start)
        return res

    return inner
def get_mem_use():
    """Return the ninth whitespace-separated field of ``nvidia-smi | grep MiB``.

    Presumably this is the used-memory column of the first GPU row; that
    depends on the table layout printed by ``nvidia-smi`` -- TODO confirm
    against the installed driver version.

    Returns:
        The field as a string, exactly as printed.

    Raises:
        IndexError: If the command prints fewer than nine fields, e.g.
            because no output line contains ``MiB``.
    """
    # The context manager closes the pipe; it used to be left open.
    with os.popen("nvidia-smi | grep MiB") as pipe:
        report = pipe.read()
    # split() with no argument collapses runs of whitespace by itself, which
    # replaces the manual space-squeezing loop that could spin forever.
    fields = report.split()
    return fields[8]
if __name__ == "__main__":
    # Command line: DIR_DATA FNAME_V1 FNAME_V2 [TOL]
    # Compares the two dumps under DIR_DATA; TOL defaults to 0.01.
    if len(sys.argv) < 4:
        # Exit with a usage message instead of an IndexError on sys.argv.
        sys.exit(f"usage: {sys.argv[0]} DIR_DATA FNAME_V1 FNAME_V2 [TOL]")
    dir_data, fname_v1, fname_v2 = sys.argv[1:4]
    tol = float(sys.argv[4]) if len(sys.argv) > 4 else 0.01
    checker = LyraChecker(dir_data, tol)
    checker.cmp(fname_v1, fname_v2)