|
import torch |
|
import numpy as np |
|
import os, sys |
|
import time |
|
|
|
class LyraChecker:
    """Compare ``.npy`` dumps stored under one directory within a tolerance.

    Results are reported by printing ``[OK]`` / ``[FAIL]`` lines to stdout;
    none of the comparison methods return a value.
    """

    def __init__(self, dir_data, tol):
        """Remember the data directory and the default absolute tolerance.

        Args:
            dir_data: Directory that file names are resolved against.
            tol: Default absolute tolerance passed to ``np.allclose``.
        """
        self.dir_data = dir_data
        self.tol = tol

    def cmp(self, fpath1, fpath2="", tol=0):
        """Load two arrays and print whether they match.

        Args:
            fpath1: File name (relative to ``dir_data``) of the first array.
            fpath2: File name of the second array. When empty, ``fpath1`` is
                used as a common stem and the pair ``<stem>_1`` / ``<stem>_2``
                is compared.
            tol: One-off tolerance for this call; ``0`` keeps ``self.tol``.
        """
        saved_tol = self.tol
        if tol != 0:
            self.tol = tol
        try:
            if fpath2 == "":
                fpath2 = fpath1
                fpath1 += "_1"
                fpath2 += "_2"
            v1 = self.get_npy(fpath1)
            v2 = self.get_npy(fpath2)
            name = fpath1
            if ".npy" in fpath1:
                # Label the report with the file stem instead of the full name.
                name = ".".join(os.path.basename(fpath1).split(".")[:-1])
            self._cmp_inner(v1, v2, name)
        finally:
            # Restore the default even when loading or comparing raises, so a
            # failed call cannot leak its temporary tolerance into later ones.
            self.tol = saved_tol

    def _cmp_inner(self, v1, v2, name):
        """Bring ``v2`` to the layout of ``v1`` if shapes differ, then compare.

        When the shapes differ and axis 1 has the same length in both arrays,
        the trailing axes of ``v2`` are flattened into one. Otherwise the axes
        of ``v2`` are reordered as ``(0, 3, 1, 2)``.
        NOTE(review): the reordering looks like channels-last -> channels-first
        and requires a 4-D ``v2`` -- confirm against the producers of the dumps.
        """
        print(v1.shape, v2.shape)
        if v1.shape != v2.shape:
            if v1.shape[1] == v2.shape[1]:
                v2 = v2.reshape([v2.shape[0], v2.shape[1], -1])
            else:
                # Pure axis permutation; no need to round-trip through torch.
                v2 = np.transpose(v2, (0, 3, 1, 2))
        print(v1.shape, v2.shape)
        self._check_data(name, v1, v2)
        print(np.size(v1))

    def _check_data(self, stage, x_out, x_gt):
        """Print an ``[OK]``/``[FAIL]`` report for ``x_out`` against ``x_gt``.

        Args:
            stage: Label used in the printed report.
            x_out: Array under test.
            x_gt: Reference array.
        """
        print(f"========== {stage} =============")
        print(x_out.shape, x_gt.shape)
        if np.allclose(x_gt, x_out, atol=self.tol):
            print(f"[OK] At {stage}, tol: {self.tol}")
        else:
            # Compute the element-wise error once and reuse it for both stats.
            abs_diff = np.abs(x_gt - x_out)
            diff_cnt = np.count_nonzero(abs_diff > self.tol)
            print(f"[FAIL]At {stage}, not aligned. tol: {self.tol}")
            print(" [INFO]Max diff: ", np.max(abs_diff))
            print(" [INFO]Diff count: ", diff_cnt, ", ratio: ", round(diff_cnt/np.size(x_out), 2))
        print(f">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")

    def cmp_query(self, fpath1, fpath2):
        """Compare separate query/key/value dumps against one fused dump.

        Args:
            fpath1: File name of the query array. The key and value file names
                are derived by replacing ``"query"`` with ``"key"`` /
                ``"value"`` in this name.
            fpath2: File name of the fused array. It is indexed with five
                indices; positions 0, 1 and 2 of axis 2 are compared with the
                query, key and value arrays after swapping axes 1 and 2.
        """
        v1 = np.load(os.path.join(self.dir_data, fpath1))
        # Substitute inside the file name only, so a data directory whose path
        # happens to contain "query" is left untouched.
        vk = np.load(os.path.join(self.dir_data, fpath1.replace("query", "key")))
        vv = np.load(os.path.join(self.dir_data, fpath1.replace("query", "value")))

        v2 = np.load(os.path.join(self.dir_data, fpath2))

        q2 = v2[:, :, 0, :, :].transpose([0, 2, 1, 3])
        self._check_data("query", v1, q2)

        k2 = v2[:, :, 1, :, :].transpose([0, 2, 1, 3])
        self._check_data("key", vk, k2)

        vv2 = v2[:, :, 2, :, :].transpose([0, 2, 1, 3])
        self._check_data("value", vv, vv2)

    def _get_data_fpath(self, fname):
        """Return ``fname`` joined to ``dir_data``, ensuring a ``.npy`` suffix."""
        fpath = os.path.join(self.dir_data, fname)
        if not fpath.endswith(".npy"):
            fpath += ".npy"
        return fpath

    def get_npy(self, fname):
        """Load and return the array stored at ``fname`` under ``dir_data``."""
        fpath = self._get_data_fpath(fname)
        return np.load(fpath)
|
|
|
|
|
|
|
|
|
class MkDataHelper:
    """Generate random tensors and store tensors as ``.npy`` files."""

    def __init__(self, data_dir="/data/home/kiokaxiao/data"):
        """Remember the root directory that sub-directories are created under."""
        self.data_dir = data_dir

    def _npy_path(self, subdir, name):
        """Create ``<data_dir>/<subdir>`` if needed and return ``<name>.npy`` inside it."""
        outdir = os.path.join(self.data_dir, subdir)
        os.makedirs(outdir, exist_ok=True)
        return os.path.join(outdir, name + ".npy")

    def mkdata(self, subdir, name, shape, dtype=torch.float16):
        """Draw a random float16 tensor, save it, and return it.

        Args:
            subdir: Sub-directory of ``data_dir`` to write into.
            name: File stem; ``.npy`` is appended.
            shape: Shape passed to ``torch.randn``.
            dtype: dtype of the array written to disk.

        Returns:
            The generated tensor. It is always float16; ``dtype`` only affects
            the saved copy, not the returned tensor.
        """
        fpath = self._npy_path(subdir, name)
        data = torch.randn(shape, dtype=torch.float16)
        np.save(fpath, data.to(dtype).numpy())
        return data

    def gen_out_with_func(self, func, inputs):
        """Return ``func(inputs)``."""
        output = func(inputs)
        return output

    def savedata(self, subdir, name, data):
        """Save tensor ``data`` as ``<data_dir>/<subdir>/<name>.npy``.

        The tensor is detached first so that tensors which still require grad
        can be converted to NumPy.
        """
        fpath = self._npy_path(subdir, name)
        np.save(fpath, data.detach().cpu().numpy())
|
|
|
|
|
class TorchSaver:
    """Dump torch tensors as ``<name>_1.npy`` files inside one directory.

    Saving can be switched off at runtime by setting ``is_save`` to False.
    """

    def __init__(self, data_dir):
        """Create ``data_dir`` if it does not exist and enable saving."""
        self.data_dir = data_dir
        os.makedirs(self.data_dir, exist_ok=True)
        self.is_save = True

    def save_v(self, name, v):
        """Write tensor ``v`` to ``<data_dir>/<name>_1.npy`` when saving is enabled."""
        if self.is_save:
            array = v.detach().cpu().numpy()
            target = os.path.join(self.data_dir, name + "_1.npy")
            np.save(target, array)

    def save_v2(self, name, v):
        """Write tensor ``v`` to ``<data_dir>/<name>_1.npy`` when saving is enabled.

        NOTE(review): this writes the same ``_1`` suffix as ``save_v``; confirm
        whether a ``_2`` suffix was intended before changing it.
        """
        if self.is_save:
            array = v.detach().cpu().numpy()
            target = os.path.join(self.data_dir, name + "_1.npy")
            np.save(target, array)
|
|
|
def timer_annoc(funct):
    """Decorator that prints the wall-clock duration of each call to ``funct``.

    After ``funct`` returns, the wrapper waits for outstanding CUDA work (when
    CUDA is available) before stopping the clock, then prints
    ``torch cost:  <seconds>`` and returns the wrapped function's result.

    Args:
        funct: Callable to time.

    Returns:
        A wrapper with the same call signature and metadata as ``funct``.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import list.
    import functools

    @functools.wraps(funct)
    def inner(*args, **kwargs):
        start = time.perf_counter()
        res = funct(*args, **kwargs)
        # Synchronizing without a usable CUDA device raises, so only do it
        # when a device is actually present.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        end = time.perf_counter()
        print("torch cost: ", end-start)
        return res

    return inner
|
|
|
def get_mem_use():
    """Return the ninth whitespace-separated field of ``nvidia-smi | grep MiB``.

    NOTE(review): with the usual ``nvidia-smi`` table layout this field is
    presumably the used-memory figure of the first GPU (e.g. ``"500MiB"``) --
    confirm against the driver version in use.

    Returns:
        The selected field as a string.

    Raises:
        IndexError: If the command output has fewer than nine fields, for
            example when ``nvidia-smi`` is missing or prints no ``MiB`` line.
    """
    # The context manager closes the pipe once the output has been read.
    with os.popen("nvidia-smi | grep MiB") as f:
        line = f.read().strip()
    # str.split() with no separator already treats any run of whitespace as
    # one delimiter, so no manual space-collapsing loop is needed (the previous
    # loop could not terminate once the output contained a space).
    fields = line.split()
    if len(fields) <= 8:
        raise IndexError(f"unexpected nvidia-smi output: {line!r}")
    memuse = fields[8]
    return memuse
|
|
|
if __name__ == "__main__":
    # Command line: <dir_data> <fname_v1> <fname_v2> [tol]
    # Compares the two named arrays under dir_data; tol defaults to 0.01.
    if len(sys.argv) < 4:
        # Exit with a usage hint instead of an IndexError traceback.
        sys.exit(f"usage: {sys.argv[0]} <dir_data> <fname_v1> <fname_v2> [tol]")
    dir_data, fname_v1, fname_v2 = sys.argv[1:4]
    tol = float(sys.argv[4]) if len(sys.argv) > 4 else 0.01
    checker = LyraChecker(dir_data, tol)
    checker.cmp(fname_v1, fname_v2)