Spaces:
Runtime error
Runtime error
yuanze
commited on
Commit
•
f15a1cd
1
Parent(s):
6dc0a5f
init
Browse files- .gitignore +2 -0
- README.md +3 -3
- app.py +129 -0
- change_setup.txt +38 -0
- data/objaverse_uni3d_3D_embeddings.pt +3 -0
- data/objaverse_uni3d_image_above_embeddings.pt +3 -0
- data/objaverse_uni3d_image_back_embeddings.pt +3 -0
- data/objaverse_uni3d_image_below_embeddings.pt +3 -0
- data/objaverse_uni3d_image_diag_above_embeddings.pt +3 -0
- data/objaverse_uni3d_image_diag_below_embeddings.pt +3 -0
- data/objaverse_uni3d_image_front_embeddings.pt +3 -0
- data/objaverse_uni3d_image_left_embeddings.pt +3 -0
- data/objaverse_uni3d_image_right_embeddings.pt +3 -0
- data/objaverse_uni3d_text_embeddings.pt +3 -0
- data/source_id_list.pt +3 -0
- dockerfile +19 -0
- feature_extractors/__init__.py +56 -0
- feature_extractors/uni3d_embedding_encoder.py +337 -0
- packages +1 -0
- requirements.txt +9 -0
- utils/bpe_simple_vocab_16e6.txt.gz +3 -0
- utils/tokenizer.py +147 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
.cache
|
2 |
+
__pycache__/
|
README.md
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
---
|
2 |
title: LD T3D
|
3 |
-
emoji:
|
4 |
colorFrom: indigo
|
5 |
colorTo: yellow
|
6 |
-
sdk:
|
7 |
-
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
|
|
1 |
---
|
2 |
title: LD T3D
|
3 |
+
emoji: 🐳
|
4 |
colorFrom: indigo
|
5 |
colorTo: yellow
|
6 |
+
sdk: docker
|
7 |
+
app_port: 7860
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
app.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import functools
|
6 |
+
from datasets import load_dataset
|
7 |
+
from feature_extractors.uni3d_embedding_encoder import Uni3dEmbeddingEncoder
|
8 |
+
|
9 |
+
# os.environ['HTTP_PROXY'] = 'http://192.168.48.17:18000'
|
10 |
+
# os.environ['HTTPS_PROXY'] = 'http://192.168.48.17:18000'
|
11 |
+
|
12 |
+
MAX_BATCH_SIZE = 16
|
13 |
+
MAX_QUEUE_SIZE = 10
|
14 |
+
MAX_K_RETRIEVAL = 20
|
15 |
+
cache_dir = "./.cache"
|
16 |
+
|
17 |
+
encoder = Uni3dEmbeddingEncoder(cache_dir)
|
18 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
19 |
+
source_id_list = torch.load("data/source_id_list.pt")
|
20 |
+
source_to_id = {source_id: i for i, source_id in enumerate(source_id_list)}
|
21 |
+
dataset = load_dataset("VAST-AI/LD-T3D", name=f"rendered_imgs_diag_above", split="base", cache_dir=cache_dir)
|
22 |
+
|
23 |
+
@functools.lru_cache()
|
24 |
+
def get_embedding(option, modality, angle=None):
|
25 |
+
save_path = f'data/objaverse_{option}_{modality + (("_" + str(angle)) if angle is not None else "")}_embeddings.pt'
|
26 |
+
if os.path.exists(save_path):
|
27 |
+
return torch.load(save_path)
|
28 |
+
else:
|
29 |
+
return gr.Error(f"Embedding file not found: {save_path}")
|
30 |
+
|
31 |
+
def predict(xb, xq, top_k):
|
32 |
+
xb = xb.to(xq.device)
|
33 |
+
sim = xq @ xb.T # (nq, nb)
|
34 |
+
_, indices = sim.topk(k=top_k, largest=True)
|
35 |
+
return indices
|
36 |
+
|
37 |
+
def get_image(index):
|
38 |
+
return dataset[index]["image"]
|
39 |
+
|
40 |
+
def retrieve_3D_models(textual_query, top_k, modality_list):
|
41 |
+
if textual_query == "":
|
42 |
+
raise gr.Error("Please enter a textual query")
|
43 |
+
if len(textual_query.split()) > 20:
|
44 |
+
gr.Warning("Retrieval result may be inaccurate due to long textual query")
|
45 |
+
if len(modality_list) == 0:
|
46 |
+
raise gr.Error("Please select at least one modality")
|
47 |
+
|
48 |
+
def _retrieve_3D_models(query, top_k, modals:list):
|
49 |
+
option = "uni3d"
|
50 |
+
op = "add"
|
51 |
+
is_text = True if "text" in modals else False
|
52 |
+
is_3D = True if "3D" in modals else False
|
53 |
+
if is_text:
|
54 |
+
modals.remove("text")
|
55 |
+
if is_3D:
|
56 |
+
modals.remove("3D")
|
57 |
+
angles = modals
|
58 |
+
|
59 |
+
# get base embeddings
|
60 |
+
embeddings = []
|
61 |
+
if is_text:
|
62 |
+
embeddings.append(get_embedding(option, "text"))
|
63 |
+
if len(angles) > 0:
|
64 |
+
for angle in angles:
|
65 |
+
embeddings.append(get_embedding(option, "image", angle=angle))
|
66 |
+
if is_3D:
|
67 |
+
embeddings.append(get_embedding(option, "3D"))
|
68 |
+
|
69 |
+
## fuse base embeddings
|
70 |
+
if len(embeddings) > 1:
|
71 |
+
if op == "concat":
|
72 |
+
embeddings = torch.cat(embeddings, dim=-1)
|
73 |
+
elif op == "add":
|
74 |
+
embeddings = sum(embeddings)
|
75 |
+
else:
|
76 |
+
raise ValueError(f"Unsupported operation: {op}")
|
77 |
+
embeddings /= embeddings.norm(dim=-1, keepdim=True)
|
78 |
+
else:
|
79 |
+
embeddings = embeddings[0]
|
80 |
+
|
81 |
+
# encode query embeddings
|
82 |
+
xq = encoder.encode_query(query)
|
83 |
+
if op == "concat":
|
84 |
+
xq = xq.repeat(1, embeddings.shape[-1] // xq.shape[-1]) # repeat to be aligned with the xb
|
85 |
+
xq /= xq.norm(dim=-1, keepdim=True)
|
86 |
+
|
87 |
+
pred_ind_list = predict(embeddings, xq, top_k)
|
88 |
+
return pred_ind_list[0].cpu().tolist() # we have only one query
|
89 |
+
|
90 |
+
indices = _retrieve_3D_models(textual_query, top_k, modality_list)
|
91 |
+
return [get_image(index) for index in indices]
|
92 |
+
|
93 |
+
def launch():
|
94 |
+
with gr.Blocks() as demo:
|
95 |
+
with gr.Row():
|
96 |
+
textual_query = gr.Textbox(label="Textual Query", autofocus=True,
|
97 |
+
placeholder="A chair with a wooden frame and a cushioned seat")
|
98 |
+
modality_list = gr.CheckboxGroup(label="Modality List", value=[],
|
99 |
+
choices=["text", "front", "back", "left", "right", "above",
|
100 |
+
"below", "diag_above", "diag_below", "3D"])
|
101 |
+
with gr.Row():
|
102 |
+
top_k = gr.Slider(minimum=1, maximum=MAX_K_RETRIEVAL, step=1, label="Top K Retrieval Result",
|
103 |
+
value=5, scale=2)
|
104 |
+
run = gr.Button("Search", scale=1)
|
105 |
+
clear_button = gr.ClearButton(scale=1)
|
106 |
+
with gr.Row():
|
107 |
+
output = gr.Gallery(format="webp", label="Retrieval Result", columns=5, type="pil")
|
108 |
+
run.click(retrieve_3D_models, [textual_query, top_k, modality_list], output,
|
109 |
+
# batch=True, max_batch_size=MAX_BATCH_SIZE
|
110 |
+
)
|
111 |
+
clear_button.click(lambda: ["", 5, [], []], outputs=[textual_query, top_k, modality_list, output])
|
112 |
+
examples = gr.Examples(examples=[["An ice cream with a cherry on top", 10, ["text", "front", "back", "left", "right", "above", "below", "diag_above", "diag_below", "3D"]],
|
113 |
+
["A mid-age castle", 10, ["text", "front", "back", "left", "right", "above", "below", "diag_above", "diag_below", "3D"]],
|
114 |
+
["A coke", 10, ["text", "front", "back", "left", "right", "above", "below", "diag_above", "diag_below", "3D"]]],
|
115 |
+
inputs=[textual_query, top_k, modality_list],
|
116 |
+
# cache_examples=True,
|
117 |
+
outputs=output,
|
118 |
+
fn=retrieve_3D_models)
|
119 |
+
|
120 |
+
demo.queue(max_size=10)
|
121 |
+
|
122 |
+
# os.environ.pop('HTTP_PROXY')
|
123 |
+
# os.environ.pop('HTTPS_PROXY')
|
124 |
+
|
125 |
+
demo.launch(server_name='0.0.0.0')
|
126 |
+
|
127 |
+
if __name__ == "__main__":
|
128 |
+
launch()
|
129 |
+
# print(len(retrieve_3D_models("A chair with a wooden frame and a cushioned seat", 5, ["3D", "diag_above", "diag_below"])))
|
change_setup.txt
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
import os.path as osp
|
4 |
+
|
5 |
+
from setuptools import find_packages, setup
|
6 |
+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
7 |
+
|
8 |
+
this_dir = osp.dirname(osp.abspath(__file__))
|
9 |
+
_ext_src_root = osp.join("pointnet2_ops", "_ext-src")
|
10 |
+
_ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
|
11 |
+
osp.join(_ext_src_root, "src", "*.cu")
|
12 |
+
)
|
13 |
+
_ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
|
14 |
+
|
15 |
+
requirements = ["torch>=1.4"]
|
16 |
+
|
17 |
+
exec(open(osp.join("pointnet2_ops", "_version.py")).read())
|
18 |
+
|
19 |
+
setup(
|
20 |
+
name="pointnet2_ops",
|
21 |
+
version=__version__,
|
22 |
+
author="Erik Wijmans",
|
23 |
+
packages=find_packages(),
|
24 |
+
install_requires=requirements,
|
25 |
+
ext_modules=[
|
26 |
+
CUDAExtension(
|
27 |
+
name="pointnet2_ops._ext",
|
28 |
+
sources=_ext_sources,
|
29 |
+
extra_compile_args={
|
30 |
+
"cxx": ["-O3"],
|
31 |
+
"nvcc": ["-O3", "-Xfatbin", "-compress-all"],
|
32 |
+
},
|
33 |
+
include_dirs=[osp.join(this_dir, _ext_src_root, "include")],
|
34 |
+
)
|
35 |
+
],
|
36 |
+
cmdclass={"build_ext": BuildExtension},
|
37 |
+
include_package_data=True,
|
38 |
+
)
|
data/objaverse_uni3d_3D_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b05400ab75009785535bd78d859db0a902176fbeb5df2ef73e55a95990ded1b8
|
3 |
+
size 365511995
|
data/objaverse_uni3d_image_above_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c0708d9bfb4df4e6f86a21bd5a1096401c8c037e84575e6d0397efdb1b138289
|
3 |
+
size 365512104
|
data/objaverse_uni3d_image_back_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5667981bc1215e1f60c034ff8e2d214da6186a2f3212061b8ed3e1c32073ad6e
|
3 |
+
size 365512104
|
data/objaverse_uni3d_image_below_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f91df0329424657666dd9a5b3181d52f9155ad545dc22a2f725f24f9b854abbd
|
3 |
+
size 365512104
|
data/objaverse_uni3d_image_diag_above_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b44e2ee38885128e9080c75ee1d311fee8f718375e867c2209273649455c89a7
|
3 |
+
size 365512035
|
data/objaverse_uni3d_image_diag_below_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:79eb0da600d75874e22bbfcca6001669eb14f06ec37326bf5148521db82f3e34
|
3 |
+
size 365512035
|
data/objaverse_uni3d_image_front_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:016208fa7a76e959840c128c30e178a0b43a570cf7a8e6cfd6fcdb442f6b72db
|
3 |
+
size 365512104
|
data/objaverse_uni3d_image_left_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c5db0c17a56ebbb0fa1323b105dfe04386f8d7f88c876bc24b943e8713a01076
|
3 |
+
size 365512035
|
data/objaverse_uni3d_image_right_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f5fb149475c79b465157d5b2cfe2af4ad8947ff23f99577da264c2632bc9d770
|
3 |
+
size 365512035
|
data/objaverse_uni3d_text_embeddings.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a2d908630bcc8a5a231e8b5d11714c63a3e8b6d78427a82a833da9219b2a7263
|
3 |
+
size 365512020
|
data/source_id_list.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c218ccb58d0045b0b6671c1378ee43362054b890f9895d7cac3de727683a9a76
|
3 |
+
size 3747900
|
dockerfile
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
FROM nvcr.io/nvidia/pytorch:23.08
|
3 |
+
|
4 |
+
LABEL maintainer="yuanze"
|
5 |
+
LABEL email="yuanze1024@gmail.com"
|
6 |
+
|
7 |
+
# Install webp support
|
8 |
+
RUN apt update && apt install libwebp-dev -y
|
9 |
+
|
10 |
+
RUN pip install -r requirements.txt
|
11 |
+
|
12 |
+
# note that you may need to modify the TORCH_CUDA_ARCH_LIST in the setup.py file
|
13 |
+
ENV TORCH_CUDA_ARCH_LIST="8.6"
|
14 |
+
|
15 |
+
# Install Pointnet2_PyTorch
|
16 |
+
RUN git clone https://github.com/erikwijmans/Pointnet2_PyTorch.git \
|
17 |
+
&& mv -f backup_install.txt Pointnet2_PyTorch/pointnet2_ops_lib/setup.py \
|
18 |
+
&& cd Pointnet2_PyTorch/pointnet2_ops_lib \
|
19 |
+
&& python install .
|
feature_extractors/__init__.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections.abc import Sequence
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
import torch
|
4 |
+
from PIL.Image import Image
|
5 |
+
|
6 |
+
class FeatureExtractor(ABC):
|
7 |
+
@abstractmethod
|
8 |
+
def encode_image(self, img_list: Sequence[Image]) -> torch.Tensor:
|
9 |
+
"""
|
10 |
+
Encode the input images and return the corresponding embeddings.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
img_list: A list of PIL.Image.Image objects.
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
The embeddings of the input images. The shape should be (len(img_list), embedding_dim).
|
17 |
+
"""
|
18 |
+
raise NotImplementedError
|
19 |
+
|
20 |
+
@abstractmethod
|
21 |
+
def encode_text(self, text_list: Sequence[str]) -> torch.Tensor:
|
22 |
+
"""
|
23 |
+
Encode the input text data and return the corresponding embeddings.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
text_list: A list of strings.
|
27 |
+
|
28 |
+
Returns:
|
29 |
+
The embeddings of the input text data. The shape should be (len(text_list), embedding_dim).
|
30 |
+
"""
|
31 |
+
raise NotImplementedError
|
32 |
+
|
33 |
+
@abstractmethod
|
34 |
+
def encode_3D(self, pc_tensor: torch.Tensor) -> torch.Tensor:
|
35 |
+
"""
|
36 |
+
Encode the input 3D point cloud and return the corresponding embeddings.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
pc_tensor: A tensor of shape (B, N, 3 + 3).
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
The embeddings of the input 3D point cloud. The shape should be (B, embedding_dim).
|
43 |
+
"""
|
44 |
+
raise NotImplementedError
|
45 |
+
|
46 |
+
@abstractmethod
|
47 |
+
def encode_query(self, queries: Sequence[str]) -> torch.Tensor:
|
48 |
+
"""Encode the queries and return the corresponding embeddings.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
queries: A list of strings.
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
The embeddings of the input text data. The shape should be (len(input_text), embedding_dim).
|
55 |
+
"""
|
56 |
+
raise NotImplementedError
|
feature_extractors/uni3d_embedding_encoder.py
ADDED
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
See https://github.com/baaivision/Uni3D for source code
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import timm
|
8 |
+
import numpy as np
|
9 |
+
from pointnet2_ops import pointnet2_utils
|
10 |
+
import open_clip
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
+
import sys
|
13 |
+
sys.path.append('')
|
14 |
+
from feature_extractors import FeatureExtractor
|
15 |
+
from utils.tokenizer import SimpleTokenizer
|
16 |
+
|
17 |
+
import logging
|
18 |
+
|
19 |
+
def fps(data, number):
|
20 |
+
'''
|
21 |
+
data B N 3
|
22 |
+
number int
|
23 |
+
'''
|
24 |
+
fps_idx = pointnet2_utils.furthest_point_sample(data, number)
|
25 |
+
fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous()
|
26 |
+
return fps_data
|
27 |
+
|
28 |
+
# https://github.com/Strawberry-Eat-Mango/PCT_Pytorch/blob/main/util.py
|
29 |
+
def knn_point(nsample, xyz, new_xyz):
|
30 |
+
"""
|
31 |
+
Input:
|
32 |
+
nsample: max sample number in local region
|
33 |
+
xyz: all points, [B, N, C]
|
34 |
+
new_xyz: query points, [B, S, C]
|
35 |
+
Return:
|
36 |
+
group_idx: grouped points index, [B, S, nsample]
|
37 |
+
"""
|
38 |
+
sqrdists = square_distance(new_xyz, xyz)
|
39 |
+
_, group_idx = torch.topk(sqrdists, nsample, dim = -1, largest=False, sorted=False)
|
40 |
+
return group_idx
|
41 |
+
|
42 |
+
def square_distance(src, dst):
|
43 |
+
"""
|
44 |
+
Calculate Euclid distance between each two points.
|
45 |
+
src^T * dst = xn * xm + yn * ym + zn * zm;
|
46 |
+
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
|
47 |
+
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
|
48 |
+
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
|
49 |
+
= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
|
50 |
+
Input:
|
51 |
+
src: source points, [B, N, C]
|
52 |
+
dst: target points, [B, M, C]
|
53 |
+
Output:
|
54 |
+
dist: per-point square distance, [B, N, M]
|
55 |
+
"""
|
56 |
+
B, N, _ = src.shape
|
57 |
+
_, M, _ = dst.shape
|
58 |
+
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
|
59 |
+
dist += torch.sum(src ** 2, -1).view(B, N, 1)
|
60 |
+
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
|
61 |
+
return dist
|
62 |
+
|
63 |
+
|
64 |
+
class PatchDropout(nn.Module):
|
65 |
+
"""
|
66 |
+
https://arxiv.org/abs/2212.00794
|
67 |
+
"""
|
68 |
+
|
69 |
+
def __init__(self, prob, exclude_first_token=True):
|
70 |
+
super().__init__()
|
71 |
+
assert 0 <= prob < 1.
|
72 |
+
self.prob = prob
|
73 |
+
self.exclude_first_token = exclude_first_token # exclude CLS token
|
74 |
+
logging.info("patch dropout prob is {}".format(prob))
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
# if not self.training or self.prob == 0.:
|
78 |
+
# return x
|
79 |
+
|
80 |
+
if self.exclude_first_token:
|
81 |
+
cls_tokens, x = x[:, :1], x[:, 1:]
|
82 |
+
else:
|
83 |
+
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
|
84 |
+
|
85 |
+
batch = x.size()[0]
|
86 |
+
num_tokens = x.size()[1]
|
87 |
+
|
88 |
+
batch_indices = torch.arange(batch)
|
89 |
+
batch_indices = batch_indices[..., None]
|
90 |
+
|
91 |
+
keep_prob = 1 - self.prob
|
92 |
+
num_patches_keep = max(1, int(num_tokens * keep_prob))
|
93 |
+
|
94 |
+
rand = torch.randn(batch, num_tokens)
|
95 |
+
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
96 |
+
|
97 |
+
x = x[batch_indices, patch_indices_keep]
|
98 |
+
|
99 |
+
if self.exclude_first_token:
|
100 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
101 |
+
|
102 |
+
return x
|
103 |
+
|
104 |
+
|
105 |
+
class Group(nn.Module):
|
106 |
+
def __init__(self, num_group, group_size):
|
107 |
+
super().__init__()
|
108 |
+
self.num_group = num_group
|
109 |
+
self.group_size = group_size
|
110 |
+
|
111 |
+
def forward(self, xyz, color):
|
112 |
+
'''
|
113 |
+
input: B N 3
|
114 |
+
---------------------------
|
115 |
+
output: B G M 3
|
116 |
+
center : B G 3
|
117 |
+
'''
|
118 |
+
batch_size, num_points, _ = xyz.shape
|
119 |
+
# fps the centers out
|
120 |
+
center = fps(xyz, self.num_group) # B G 3
|
121 |
+
# knn to get the neighborhood
|
122 |
+
# _, idx = self.knn(xyz, center) # B G M
|
123 |
+
idx = knn_point(self.group_size, xyz, center) # B G M
|
124 |
+
assert idx.size(1) == self.num_group
|
125 |
+
assert idx.size(2) == self.group_size
|
126 |
+
idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
|
127 |
+
idx = idx + idx_base
|
128 |
+
idx = idx.view(-1)
|
129 |
+
neighborhood = xyz.view(batch_size * num_points, -1)[idx, :]
|
130 |
+
neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, 3).contiguous()
|
131 |
+
|
132 |
+
neighborhood_color = color.view(batch_size * num_points, -1)[idx, :]
|
133 |
+
neighborhood_color = neighborhood_color.view(batch_size, self.num_group, self.group_size, 3).contiguous()
|
134 |
+
|
135 |
+
# normalize
|
136 |
+
neighborhood = neighborhood - center.unsqueeze(2)
|
137 |
+
|
138 |
+
features = torch.cat((neighborhood, neighborhood_color), dim=-1)
|
139 |
+
return neighborhood, center, features
|
140 |
+
|
141 |
+
class Encoder(nn.Module):
|
142 |
+
def __init__(self, encoder_channel):
|
143 |
+
super().__init__()
|
144 |
+
self.encoder_channel = encoder_channel
|
145 |
+
self.first_conv = nn.Sequential(
|
146 |
+
nn.Conv1d(6, 128, 1),
|
147 |
+
nn.BatchNorm1d(128),
|
148 |
+
nn.ReLU(inplace=True),
|
149 |
+
nn.Conv1d(128, 256, 1)
|
150 |
+
)
|
151 |
+
self.second_conv = nn.Sequential(
|
152 |
+
nn.Conv1d(512, 512, 1),
|
153 |
+
nn.BatchNorm1d(512),
|
154 |
+
nn.ReLU(inplace=True),
|
155 |
+
nn.Conv1d(512, self.encoder_channel, 1)
|
156 |
+
)
|
157 |
+
def forward(self, point_groups):
|
158 |
+
'''
|
159 |
+
point_groups : B G N 3
|
160 |
+
-----------------
|
161 |
+
feature_global : B G C
|
162 |
+
'''
|
163 |
+
bs, g, n , _ = point_groups.shape
|
164 |
+
point_groups = point_groups.reshape(bs * g, n, 6)
|
165 |
+
# encoder
|
166 |
+
feature = self.first_conv(point_groups.transpose(2,1)) # BG 256 n
|
167 |
+
feature_global = torch.max(feature,dim=2,keepdim=True)[0] # BG 256 1
|
168 |
+
feature = torch.cat([feature_global.expand(-1,-1,n), feature], dim=1)# BG 512 n
|
169 |
+
feature = self.second_conv(feature) # BG 1024 n
|
170 |
+
feature_global = torch.max(feature, dim=2, keepdim=False)[0] # BG 1024
|
171 |
+
return feature_global.reshape(bs, g, self.encoder_channel)
|
172 |
+
|
173 |
+
class PointcloudEncoder(nn.Module):
|
174 |
+
def __init__(self, point_transformer):
|
175 |
+
# use the giant branch of uni3d
|
176 |
+
super().__init__()
|
177 |
+
from easydict import EasyDict
|
178 |
+
self.trans_dim = 1408
|
179 |
+
self.embed_dim = 1024
|
180 |
+
self.group_size = 64
|
181 |
+
self.num_group = 512
|
182 |
+
# grouper
|
183 |
+
self.group_divider = Group(num_group = self.num_group, group_size = self.group_size)
|
184 |
+
# define the encoder
|
185 |
+
self.encoder_dim = 512
|
186 |
+
self.encoder = Encoder(encoder_channel = self.encoder_dim)
|
187 |
+
|
188 |
+
# bridge encoder and transformer
|
189 |
+
self.encoder2trans = nn.Linear(self.encoder_dim, self.trans_dim)
|
190 |
+
|
191 |
+
# bridge transformer and clip embedding
|
192 |
+
self.trans2embed = nn.Linear(self.trans_dim, self.embed_dim)
|
193 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
|
194 |
+
self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))
|
195 |
+
|
196 |
+
self.pos_embed = nn.Sequential(
|
197 |
+
nn.Linear(3, 128),
|
198 |
+
nn.GELU(),
|
199 |
+
nn.Linear(128, self.trans_dim)
|
200 |
+
)
|
201 |
+
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
|
202 |
+
self.patch_dropout = PatchDropout(0.) if 0. > 0. else nn.Identity()
|
203 |
+
self.visual = point_transformer
|
204 |
+
|
205 |
+
|
206 |
+
def forward(self, pts, colors):
|
207 |
+
# divide the point cloud in the same form. This is important
|
208 |
+
_, center, features = self.group_divider(pts, colors)
|
209 |
+
|
210 |
+
# encoder the input cloud patches
|
211 |
+
group_input_tokens = self.encoder(features) # B G N
|
212 |
+
group_input_tokens = self.encoder2trans(group_input_tokens)
|
213 |
+
# prepare cls
|
214 |
+
cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1)
|
215 |
+
cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1)
|
216 |
+
# add pos embedding
|
217 |
+
pos = self.pos_embed(center)
|
218 |
+
# final input
|
219 |
+
x = torch.cat((cls_tokens, group_input_tokens), dim=1)
|
220 |
+
pos = torch.cat((cls_pos, pos), dim=1)
|
221 |
+
# transformer
|
222 |
+
x = x + pos
|
223 |
+
# x = x.half()
|
224 |
+
|
225 |
+
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
|
226 |
+
x = self.patch_dropout(x)
|
227 |
+
|
228 |
+
x = self.visual.pos_drop(x)
|
229 |
+
|
230 |
+
# ModuleList not support forward
|
231 |
+
for i, blk in enumerate(self.visual.blocks):
|
232 |
+
x = blk(x)
|
233 |
+
x = self.visual.norm(x[:, 0, :])
|
234 |
+
x = self.visual.fc_norm(x)
|
235 |
+
|
236 |
+
x = self.trans2embed(x)
|
237 |
+
return x
|
238 |
+
|
239 |
+
class Uni3D(nn.Module):
|
240 |
+
def __init__(self, point_encoder):
|
241 |
+
super().__init__()
|
242 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
243 |
+
self.point_encoder = point_encoder
|
244 |
+
|
245 |
+
def encode_pc(self, pc):
|
246 |
+
xyz = pc[:,:,:3].contiguous()
|
247 |
+
color = pc[:,:,3:].contiguous()
|
248 |
+
pc_feat = self.point_encoder(xyz, color)
|
249 |
+
return pc_feat
|
250 |
+
|
251 |
+
def forward(self, pc, text, image):
|
252 |
+
text_embed_all = text
|
253 |
+
image_embed = image
|
254 |
+
pc_embed = self.encode_pc(pc)
|
255 |
+
return {'text_embed': text_embed_all,
|
256 |
+
'pc_embed': pc_embed,
|
257 |
+
'image_embed': image_embed,
|
258 |
+
'logit_scale': self.logit_scale.exp()}
|
259 |
+
|
260 |
+
def get_metric_names(model):
|
261 |
+
return ['loss', 'uni3d_loss', 'pc_image_acc', 'pc_text_acc']
|
262 |
+
|
263 |
+
def create_uni3d(uni3d_path):
|
264 |
+
# create transformer blocks for point cloud via timm
|
265 |
+
point_transformer = timm.create_model("eva_giant_patch14_560")
|
266 |
+
|
267 |
+
# create whole point cloud encoder
|
268 |
+
point_encoder = PointcloudEncoder(point_transformer)
|
269 |
+
|
270 |
+
# uni3d model
|
271 |
+
model = Uni3D(point_encoder=point_encoder,)
|
272 |
+
|
273 |
+
checkpoint = torch.load(uni3d_path, map_location='cpu')
|
274 |
+
logging.info('loaded checkpoint {}'.format(uni3d_path))
|
275 |
+
sd = checkpoint['module']
|
276 |
+
if next(iter(sd.items()))[0].startswith('module'):
|
277 |
+
sd = {k[len('module.'):]: v for k, v in sd.items()}
|
278 |
+
model.load_state_dict(sd)
|
279 |
+
return model
|
280 |
+
|
281 |
+
class Uni3dEmbeddingEncoder(FeatureExtractor):
|
282 |
+
def __init__(self, cache_dir, **kwargs) -> None:
|
283 |
+
bpe_path = "utils/bpe_simple_vocab_16e6.txt.gz"
|
284 |
+
uni3d_path = os.path.join(cache_dir, "Uni3D", "modelzoo", "uni3d-g", "model.pt") # concat the subfolder as hf_hub_download will put it here
|
285 |
+
clip_path = os.path.join(cache_dir, "Uni3D", "open_clip_pytorch_model.bin")
|
286 |
+
|
287 |
+
if not os.path.exists(uni3d_path):
|
288 |
+
hf_hub_download("BAAI/Uni3D", "model.pt", subfolder="modelzoo/uni3d-g", cache_dir=cache_dir,
|
289 |
+
local_dir=cache_dir + os.sep + "Uni3D")
|
290 |
+
if not os.path.exists(clip_path):
|
291 |
+
hf_hub_download("timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k", "open_clip_pytorch_model.bin",
|
292 |
+
cache_dir=cache_dir, local_dir=cache_dir + os.sep + "Uni3D")
|
293 |
+
|
294 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
295 |
+
self.tokenizer = SimpleTokenizer(bpe_path)
|
296 |
+
self.model = create_uni3d(uni3d_path)
|
297 |
+
self.model.eval()
|
298 |
+
self.model.to(self.device)
|
299 |
+
self.clip_model, _, self.preprocess = open_clip.create_model_and_transforms(model_name="EVA02-E-14-plus", pretrained=clip_path)
|
300 |
+
self.clip_model.to(self.device)
|
301 |
+
|
302 |
+
def pc_norm(self, pc):
|
303 |
+
""" pc: NxC, return NxC """
|
304 |
+
centroid = np.mean(pc, axis=0)
|
305 |
+
pc = pc - centroid
|
306 |
+
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
|
307 |
+
pc = pc / m
|
308 |
+
return pc
|
309 |
+
|
310 |
+
@torch.no_grad()
|
311 |
+
def encode_3D(self, data):
|
312 |
+
pc = data.to(device=self.device, non_blocking=True)
|
313 |
+
pc_features = self.model.encode_pc(pc)
|
314 |
+
pc_features = pc_features / pc_features.norm(dim=-1, keepdim=True)
|
315 |
+
return pc_features.float()
|
316 |
+
|
317 |
+
@torch.no_grad()
|
318 |
+
def encode_text(self, input_text):
|
319 |
+
texts = self.tokenizer(input_text).to(device=self.device, non_blocking=True)
|
320 |
+
if len(texts.shape) < 2:
|
321 |
+
texts = texts[None, ...]
|
322 |
+
class_embeddings = self.clip_model.encode_text(texts)
|
323 |
+
class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
|
324 |
+
return class_embeddings.float()
|
325 |
+
|
326 |
+
@torch.no_grad()
|
327 |
+
def encode_image(self, img_tensor_list):
|
328 |
+
image = img_tensor_list.to(device=self.device, non_blocking=True)
|
329 |
+
image_features = self.clip_model.encode_image(image)
|
330 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
331 |
+
return image_features.float()
|
332 |
+
|
333 |
+
def encode_query(self, query_list):
|
334 |
+
return self.encode_text(query_list)
|
335 |
+
|
336 |
+
def get_img_transform(self):
|
337 |
+
return self.preprocess
|
packages
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
libwebp-dev
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
datasets
|
3 |
+
timm
|
4 |
+
pillow
|
5 |
+
open-clip-torch
|
6 |
+
huggingface_hub
|
7 |
+
ftfy
|
8 |
+
regex
|
9 |
+
easydict
|
utils/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
3 |
+
size 1356917
|
utils/tokenizer.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copied from github.com/baaivision/Uni3D
|
2 |
+
# # Modified from github.com/openai/CLIP
|
3 |
+
import gzip
|
4 |
+
import html
|
5 |
+
import os
|
6 |
+
from functools import lru_cache
|
7 |
+
|
8 |
+
import ftfy
|
9 |
+
import regex as re
|
10 |
+
import torch
|
11 |
+
|
12 |
+
|
13 |
+
@lru_cache()
|
14 |
+
def bytes_to_unicode():
|
15 |
+
"""
|
16 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
17 |
+
The reversible bpe codes work on unicode strings.
|
18 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
19 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
20 |
+
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
21 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
22 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
23 |
+
"""
|
24 |
+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
25 |
+
cs = bs[:]
|
26 |
+
n = 0
|
27 |
+
for b in range(2**8):
|
28 |
+
if b not in bs:
|
29 |
+
bs.append(b)
|
30 |
+
cs.append(2**8+n)
|
31 |
+
n += 1
|
32 |
+
cs = [chr(n) for n in cs]
|
33 |
+
return dict(zip(bs, cs))
|
34 |
+
|
35 |
+
|
36 |
+
def get_pairs(word):
|
37 |
+
"""Return set of symbol pairs in a word.
|
38 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
39 |
+
"""
|
40 |
+
pairs = set()
|
41 |
+
prev_char = word[0]
|
42 |
+
for char in word[1:]:
|
43 |
+
pairs.add((prev_char, char))
|
44 |
+
prev_char = char
|
45 |
+
return pairs
|
46 |
+
|
47 |
+
|
48 |
+
def basic_clean(text):
|
49 |
+
text = ftfy.fix_text(text)
|
50 |
+
text = html.unescape(html.unescape(text))
|
51 |
+
return text.strip()
|
52 |
+
|
53 |
+
|
54 |
+
def whitespace_clean(text):
|
55 |
+
text = re.sub(r'\s+', ' ', text)
|
56 |
+
text = text.strip()
|
57 |
+
return text
|
58 |
+
|
59 |
+
|
60 |
+
class SimpleTokenizer(object):
|
61 |
+
def __init__(self, bpe_path):
|
62 |
+
self.byte_encoder = bytes_to_unicode()
|
63 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
64 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
65 |
+
merges = merges[1:49152-256-2+1]
|
66 |
+
merges = [tuple(merge.split()) for merge in merges]
|
67 |
+
vocab = list(bytes_to_unicode().values())
|
68 |
+
vocab = vocab + [v+'</w>' for v in vocab]
|
69 |
+
for merge in merges:
|
70 |
+
vocab.append(''.join(merge))
|
71 |
+
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
72 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
73 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
74 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
75 |
+
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
76 |
+
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
77 |
+
|
78 |
+
def bpe(self, token):
|
79 |
+
if token in self.cache:
|
80 |
+
return self.cache[token]
|
81 |
+
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
82 |
+
pairs = get_pairs(word)
|
83 |
+
|
84 |
+
if not pairs:
|
85 |
+
return token+'</w>'
|
86 |
+
|
87 |
+
while True:
|
88 |
+
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
89 |
+
if bigram not in self.bpe_ranks:
|
90 |
+
break
|
91 |
+
first, second = bigram
|
92 |
+
new_word = []
|
93 |
+
i = 0
|
94 |
+
while i < len(word):
|
95 |
+
try:
|
96 |
+
j = word.index(first, i)
|
97 |
+
new_word.extend(word[i:j])
|
98 |
+
i = j
|
99 |
+
except:
|
100 |
+
new_word.extend(word[i:])
|
101 |
+
break
|
102 |
+
|
103 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
104 |
+
new_word.append(first+second)
|
105 |
+
i += 2
|
106 |
+
else:
|
107 |
+
new_word.append(word[i])
|
108 |
+
i += 1
|
109 |
+
new_word = tuple(new_word)
|
110 |
+
word = new_word
|
111 |
+
if len(word) == 1:
|
112 |
+
break
|
113 |
+
else:
|
114 |
+
pairs = get_pairs(word)
|
115 |
+
word = ' '.join(word)
|
116 |
+
self.cache[token] = word
|
117 |
+
return word
|
118 |
+
|
119 |
+
def encode(self, text):
|
120 |
+
bpe_tokens = []
|
121 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
122 |
+
for token in re.findall(self.pat, text):
|
123 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
124 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
125 |
+
return bpe_tokens
|
126 |
+
|
127 |
+
def decode(self, tokens):
|
128 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
129 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
130 |
+
return text
|
131 |
+
|
132 |
+
def __call__(self, texts, context_length=77):
|
133 |
+
if isinstance(texts, str):
|
134 |
+
texts = [texts]
|
135 |
+
|
136 |
+
sot_token = self.encoder["<|startoftext|>"]
|
137 |
+
eot_token = self.encoder["<|endoftext|>"]
|
138 |
+
all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
|
139 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
140 |
+
|
141 |
+
for i, tokens in enumerate(all_tokens):
|
142 |
+
tokens = tokens[:context_length]
|
143 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
144 |
+
|
145 |
+
if len(result) == 1:
|
146 |
+
return result[0]
|
147 |
+
return result
|