LASA / process_scripts /extract_img_vit_features.py
HaolinLiu's picture
first commit of codes and update readme.md
cc9780d
import os,sys
sys.path.append("..")
from util.simple_image_loader import Image_dataset
from torch.utils.data import DataLoader
import timm
import torch
from tqdm import tqdm
import numpy as np
from transformers import DPTForDepthEstimation, DPTFeatureExtractor
import argparse
from util import misc
from datasets.taxonomy import synthetic_arkit_category_combined
# Command-line interface for the image-feature extraction script.
parser = argparse.ArgumentParser()
# --category is required: the dataset setup indexes args.category[0], which
# would otherwise fail with an opaque TypeError when the flag is omitted.
parser.add_argument(
    "--category",
    nargs="+",
    type=str,
    required=True,
    help='one or more category names to process; pass "all" to use the '
         'combined category list from datasets.taxonomy',
)
parser.add_argument(
    "--root_dir",
    type=str,
    default="../data",
    help="dataset root folder; extracted features are written below it",
)
parser.add_argument(
    "--ckpt_path",
    type=str,
    default="../open_clip_pytorch_model.bin",
    help="checkpoint file handed to timm as the pretrained weights",
)
parser.add_argument(
    "--batch_size",
    type=int,
    default=24,
    help="number of images per batch in each process",
)
# The remaining flags are not read directly in this script; presumably they are
# consumed by misc.init_distributed_mode -- confirm against util/misc.py.
parser.add_argument('--world_size', default=1, type=int,
                    help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://',
                    help='url used to set up distributed training')
args = parser.parse_args()

# Initialise the distributed setup from the parsed arguments before any
# rank / world-size queries are made.
misc.init_distributed_mode(args)
# Resolve which categories to process: the single keyword "all" expands to
# synthetic_arkit_category_combined["all"]; any other input is used as given.
category = (
    synthetic_arkit_category_combined["all"]
    if args.category[0] == "all"
    else args.category
)

print("loading dataset")
dataset = Image_dataset(
    dataset_folder=args.root_dir,
    categories=category,
    n_px=224,
)

# Give every process its own slice of the dataset. shuffle=False keeps the
# iteration order fixed from run to run.
# NOTE(review): per the PyTorch docs the sampler may repeat a few samples so
# that all replicas receive the same count -- confirm duplicates are harmless.
num_tasks = misc.get_world_size()
global_rank = misc.get_rank()
sampler = torch.utils.data.DistributedSampler(
    dataset,
    num_replicas=num_tasks,
    rank=global_rank,
    shuffle=False,
)

# drop_last=False so the final, possibly smaller, batch is still processed.
dataloader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    sampler=sampler,
    num_workers=4,
    pin_memory=True,
    drop_last=False,
)
print("loading model")
VIT_MODEL = 'vit_huge_patch14_224_clip_laion2b'

# The overlay points timm's pretrained configuration at the checkpoint file
# supplied on the command line.
pretrained_overlay = dict(file=args.ckpt_path)
model = timm.create_model(
    VIT_MODEL,
    pretrained=True,
    pretrained_cfg_overlay=pretrained_overlay,
)
# Inference only: evaluation mode, float32 parameters, on the GPU.
model = model.eval().float().cuda()

# Root folder under which the per-image feature files are stored.
save_dir = os.path.join(args.root_dir, "other_data")
# Run the model over every batch and store one compressed .npz file of
# features per image at
#   <save_dir>/<category>/7_img_features/<model_id>/<image_name>.npz
total_batches = len(dataloader)
for idx, data_batch in enumerate(dataloader):
    if idx % 50 == 0:
        # Progress is reported as "<current batch index>/<number of batches>".
        print("{}/{}".format(idx, total_batches))

    images = data_batch["images"].cuda().float()
    with torch.no_grad():
        output_features = model.forward_features(images)

    # Move the whole batch to host memory once instead of once per sample.
    batch_features = (
        output_features.detach().cpu().numpy().astype(np.float32, copy=False)
    )

    # Per-sample metadata is read straight from the batch so the module-level
    # `category` list is not overwritten by per-batch values.
    samples = zip(
        batch_features,
        data_batch["category"],
        data_batch["model_id"],
        data_batch["image_name"],
    )
    for features, sample_category, sample_model_id, sample_image_name in samples:
        save_folder = os.path.join(
            save_dir, sample_category, "7_img_features", sample_model_id
        )
        os.makedirs(save_folder, exist_ok=True)
        save_path = os.path.join(save_folder, sample_image_name + ".npz")
        np.savez_compressed(save_path, img_features=features)