# Demo script: loads the pretrained text-to-motion models at import time and
# defines `generate`, which turns a text prompt into BVH / MP4 / NPY files.
from functools import partial
import os

print("Starting")
import torch

print(f"Is CUDA available: {torch.cuda.is_available()}")
# print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

import numpy as np
import gradio as gr

import random
import shutil
from os.path import join as pjoin

import torch.nn.functional as F

from models.mask_transformer.transformer import MaskTransformer, ResidualTransformer
from models.vq.model import RVQVAE, LengthEstimator

from options.hgdemo_option import EvalT2MOptions
from utils.get_opt import get_opt

from utils.fixseed import fixseed
from visualization.joints2bvh import Joint2BVHConvertor
from torch.distributions.categorical import Categorical

from utils.motion_process import recover_from_ric
from utils.plot_script import plot_3d_motion

from utils.paramUtil import t2m_kinematic_chain

from gen_t2m import load_vq_model, load_res_model, load_trans_model, load_len_estimator

clip_version = 'ViT-B/32'

WEBSITE = """
"""
WEBSITE_bottom = """ """

# Example prompts. The commented-out entries are earlier phrasings of the
# example that follows them.
EXAMPLES = [
    "A person is running on a treadmill.",
    "The person takes 4 steps backwards.",
    "A person jumps up and then lands.",
    "The person was pushed but did not fall.",
    "The person does a salsa dance.",
    "A figure streches it hands and arms above its head.",
    "This person kicks with his right leg then jabs several times.",
    "A person stands for few seconds and picks up his arms and shakes them.",
    "A person walks in a clockwise circle and stops where he began.",
    "A man bends down and picks something up with his right hand.",
    "A person walks with a limp, their left leg gets injured.",
    "A person repeatedly blocks their face with their right arm.",
    # "The person holds his left foot with his left hand, puts his right foot up and left hand up too.",
    "The person holds their left foot with their left hand, lifting both their left foot and left hand up.",
    # "A person stands, crosses left leg in front of the right, lowering themselves until they are sitting, both hands on the floor before standing and uncrossing legs.",
    "The person stands, crosses their left leg in front of the right, lowers themselves until they are sitting with both hands on the floor, and then stands back up, uncrossing their legs.",
    "The man walked forward, spun right on one foot and walked back to his original position.",
    "A man is walking forward then steps over an object then continues walking forward.",
]

# Show closest text in the training


# css to make videos look nice
# var(--block-border-color); TODO
# NOTE(review): the string ends with two closing braces for a single rule;
# the second one looks stray -- confirm before removing.
CSS = """ .generate_video { position: relative; margin-left: auto; margin-right: auto; box-shadow: var(--block-shadow); border-width: var(--block-border-width); border-color: #000000; border-radius: var(--block-radius); background: var(--block-background-fill); width: 25%; line-height: var(--line-sm); } } """

DEFAULT_TEXT = "A person is "

# Download the checkpoints on first start and link them into the working dir.
if not os.path.exists("/data/checkpoints/t2m"):
    os.system("bash prepare/download_models_demo.sh")
if not os.path.exists("checkpoints/t2m"):
    os.system("ln -s /data/checkpoints checkpoints")

# Every generated prompt is appended to this file, one prompt per line.
_PROMPT_LOG_PATH = "/data/stats/Prompts.text"
# init number from visit, 01/07
_TOTAL_CALLS_OFFSET = 4730

# Create the log when it is missing, even if the directory already exists,
# so that reading it never fails with FileNotFoundError.
os.makedirs("/data/stats", exist_ok=True)
if not os.path.exists(_PROMPT_LOG_PATH):
    with open(_PROMPT_LOG_PATH, 'w', encoding="utf-8"):
        pass

Total_Calls = _TOTAL_CALLS_OFFSET


def update_total_calls():
    """Recount the logged prompts and store offset + count in ``Total_Calls``."""
    global Total_Calls
    # Only the number of lines matters, so undecodable bytes are replaced
    # instead of aborting the count.
    with open(_PROMPT_LOG_PATH, 'r', encoding="utf-8", errors="replace") as f:
        Total_Calls = sum(1 for _ in f) + _TOTAL_CALLS_OFFSET
    print("Prompts Num:", Total_Calls)


### Load Stats ###

##########################
######Preparing demo######
##########################
parser = EvalT2MOptions()
opt = parser.parse()
fixseed(opt.seed)
opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id))
dim_pose = 263
root_dir = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
model_dir = pjoin(root_dir, 'model')
model_opt_path = pjoin(root_dir, 'opt.txt')
model_opt = get_opt(model_opt_path, device=opt.device)

######Loading RVQ######
vq_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'opt.txt')
vq_opt = get_opt(vq_opt_path, device=opt.device)
vq_opt.dim_pose = dim_pose
vq_model, vq_opt = load_vq_model(vq_opt)

# Copy the quantizer settings onto the transformer options.
model_opt.num_tokens = vq_opt.nb_code
model_opt.num_quantizers = vq_opt.num_quantizers
model_opt.code_dim = vq_opt.code_dim

######Loading R-Transformer######
res_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.res_name, 'opt.txt')
res_opt = get_opt(res_opt_path, device=opt.device)
res_model = load_res_model(res_opt, vq_opt, opt)

assert res_opt.vq_name == model_opt.vq_name

######Loading M-Transformer######
t2m_transformer = load_trans_model(model_opt, opt, 'latest.tar')

#####Loading Length Predictor#####
length_estimator = load_len_estimator(model_opt)

t2m_transformer.eval()
vq_model.eval()
res_model.eval()
length_estimator.eval()

res_model.to(opt.device)
t2m_transformer.to(opt.device)
vq_model.to(opt.device)
length_estimator.to(opt.device)

opt.nb_joints = 22

# Normalisation statistics stored next to the VQ checkpoint.
mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'mean.npy'))
std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'std.npy'))


def inv_transform(data):
    """Undo the normalisation of *data* using the loaded ``std`` and ``mean``."""
    return data * std + mean


kinematic_chain = t2m_kinematic_chain
converter = Joint2BVHConvertor()
cached_dir = './cached'
uid = 12138
animation_path = pjoin(cached_dir, f'{uid}')
os.makedirs(animation_path, exist_ok=True)


@torch.no_grad()
def generate(
    text, uid, motion_length=0, use_ik=True, seed=10107, repeat_times=1,
):
    """Generate motion clips for a text prompt and write them to disk.

    The prompt is appended to the prompt log and the call counter is
    refreshed before any model runs.

    Args:
        text: The prompt; also used as the title of the rendered video.
        uid: Name of the sub-directory of ``cached_dir`` that receives the
            output files.
        motion_length: Requested length; it is integer-divided by 4 to get
            the token length. ``0`` means the length is sampled from the
            length estimator instead.
        use_ik: When false, the BVH conversion runs with ``foot_ik=False``.
        seed: Currently unused (the ``fixseed`` call is commented out).
        repeat_times: Number of samples to generate.

    Returns:
        A list with one ``{"url": <mp4 path>}`` dict per generated sample.
    """
    # fixseed(seed)
    print(text)
    # Flatten line breaks so that one call always adds exactly one line;
    # the call counter is derived from the line count of this file.
    logged_text = " ".join(text.splitlines())
    with open(_PROMPT_LOG_PATH, 'a', encoding="utf-8") as f:
        f.write(logged_text + '\n')
    update_total_calls()

    captions = [text]
    if motion_length == 0:
        print("Since no motion length are specified, we will use estimated motion lengthes!!")
        text_embedding = t2m_transformer.encode_text(captions)
        pred_dis = length_estimator(text_embedding)
        probs = F.softmax(pred_dis, dim=-1)  # (b, ntoken)
        token_lens = Categorical(probs).sample()  # one sampled index per row of probs
    else:
        token_lens = torch.LongTensor([motion_length]) // 4
    token_lens = token_lens.to(opt.device).long()
    m_length = token_lens * 4

    # The output directory depends only on uid, so create it once up front.
    animation_path = pjoin(cached_dir, f'{uid}')
    os.makedirs(animation_path, exist_ok=True)

    datas = []
    for r in range(repeat_times):
        mids = t2m_transformer.generate(captions, token_lens,
                                        timesteps=opt.time_steps,
                                        cond_scale=opt.cond_scale,
                                        temperature=opt.temperature,
                                        topk_filter_thres=opt.topkr,
                                        gsample=opt.gumbel_sample)
        mids = res_model.generate(mids, captions, token_lens, temperature=1, cond_scale=5)
        pred_motions = vq_model.forward_decoder(mids)
        pred_motions = pred_motions.detach().cpu().numpy()
        data = inv_transform(pred_motions)

        # Only the video name carries the random suffix; the BVH and NPY
        # names depend on the repeat index alone.
        ruid = random.randrange(999999999)
        bvh_path = pjoin(animation_path, f"sample_repeat{r}.bvh")
        save_path = pjoin(animation_path, f"sample_repeat{r}_{ruid}.mp4")
        npy_path = pjoin(animation_path, f"sample_repeat{r}.npy")

        for k, (caption, joint_data) in enumerate(zip(captions, data)):
            # Trim the sample to its own length before recovering joints.
            joint_data = joint_data[:m_length[k]]
            joint = recover_from_ric(torch.from_numpy(joint_data).float(), 22).numpy()
            if use_ik:
                print("Using IK")
                _, joint = converter.convert(joint, filename=bvh_path, iterations=100)
            else:
                _, joint = converter.convert(joint, filename=bvh_path, iterations=100, foot_ik=False)
            plot_3d_motion(save_path, kinematic_chain, joint, title=caption, fps=20)
            np.save(npy_path, joint)
            datas.append({"url": save_path})

    return datas


# HTML component
def get_video_html(data, video_id, width=700, height=700):
    url = data["url"]
    # class="wrap default svelte-gjihhp hide"
    #