Spaces:
Running
on
T4
Running
on
T4
File size: 1,386 Bytes
53625b9 06bba7e 53625b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from .ppat_rgb import Projected, PointPatchTransformer
def module(state_dict: dict, name):
    """Return the entries of ``state_dict`` that live under the ``name`` prefix.

    Only keys starting with ``name + '.'`` are kept, and that exact prefix is
    removed from each kept key. Values are passed through untouched.

    Args:
        state_dict: Mapping from dotted parameter names to values.
        name: Prefix to select. It may itself contain dots
            (e.g. ``'module.pc_encoder'``).

    Returns:
        A new dict keyed by the remainder of each matching key.
    """
    prefix = name + '.'
    # Slice off the matched prefix rather than dropping the first dotted
    # component: the two only coincide when ``name`` contains no dot.
    cut = len(prefix)
    return {k[cut:]: v for k, v in state_dict.items() if k.startswith(prefix)}
def G14(s):
    """Build the encoder registered as ``openshape-pointbert-vitg14-rgb``.

    A ``PointPatchTransformer`` is wrapped in ``Projected`` together with a
    ``Linear(512, 1280)`` layer, then populated from the entries of ``s``
    found under the ``'module.'`` prefix.

    Args:
        s: Checkpoint mapping of dotted parameter names to tensors.

    Returns:
        The constructed model with the checkpoint weights loaded.
    """
    # Backbone is created before the linear layer, as in the nested call form.
    backbone = PointPatchTransformer(512, 12, 8, 512 * 3, 256, 384, 0.2, 64, 6)
    projection = nn.Linear(512, 1280)
    encoder = Projected(backbone, projection)
    weights = module(s, 'module')
    encoder.load_state_dict(weights)
    return encoder
def L14(s):
    """Build the encoder registered as ``openshape-pointbert-vitl14-rgb``.

    A ``PointPatchTransformer`` is wrapped in ``Projected`` together with a
    ``Linear(512, 768)`` layer, then populated from the entries of ``s``
    found under the ``'pc_encoder.'`` prefix.

    Args:
        s: Checkpoint mapping of dotted parameter names to tensors.

    Returns:
        The constructed model with the checkpoint weights loaded.
    """
    # Backbone is created before the linear layer, as in the nested call form.
    backbone = PointPatchTransformer(512, 12, 8, 1024, 128, 64, 0.4, 256, 6)
    projection = nn.Linear(512, 768)
    encoder = Projected(backbone, projection)
    weights = module(s, 'pc_encoder')
    encoder.load_state_dict(weights)
    return encoder
def B32(s):
    """Build the encoder registered as ``openshape-pointbert-vitb32-rgb``.

    Unlike the other builders in this file, the ``PointPatchTransformer`` is
    returned bare (no ``Projected`` wrapper). Its weights come from the
    entries of ``s`` found under the ``'pc_encoder.'`` prefix.

    Args:
        s: Checkpoint mapping of dotted parameter names to tensors.

    Returns:
        The constructed model with the checkpoint weights loaded.
    """
    encoder = PointPatchTransformer(512, 12, 8, 1024, 128, 64, 0.4, 256, 6)
    weights = module(s, 'pc_encoder')
    encoder.load_state_dict(weights)
    return encoder
# Registry mapping a checkpoint name to the builder that constructs the
# matching model and loads its weights. load_pc_encoder() uses the same name
# twice: as this dict's key and, prefixed with "OpenShape/", as the Hub repo id.
model_list = {
"openshape-pointbert-vitb32-rgb": B32,
"openshape-pointbert-vitl14-rgb": L14,
"openshape-pointbert-vitg14-rgb": G14,
}
def load_pc_encoder(name):
    """Download a checkpoint by name and return the ready-to-use encoder.

    Fetches ``model.pt`` from the Hub repo ``"OpenShape/" + name``, loads it
    onto the CPU, builds the model through the ``model_list`` entry for
    ``name``, switches it to eval mode and moves it to CUDA when available.

    Args:
        name: A key of ``model_list``; also the repo name under ``OpenShape/``.

    Returns:
        The model in eval mode.

    Raises:
        KeyError: If the download succeeds but ``name`` is not in
            ``model_list``.
    """
    checkpoint_path = hf_hub_download("OpenShape/" + name, "model.pt", token=True)
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version's weights_only default; only use with trusted repos.
    weights = torch.load(checkpoint_path, map_location='cpu')
    build = model_list[name]
    encoder = build(weights).eval()
    if torch.cuda.is_available():
        encoder.cuda()
    return encoder
# only import the functions in demo!
from .sd_pc2img import pc_to_image
from .caption import pc_caption
from .classification import pred_lvis_sims
|