eliphatfs committed
Commit e08783a · Parent: d154ca2
Files changed (3):
  1. app.py +10 -23
  2. openshape/__init__.py +0 -1
  3. openshape/caption.py +0 -163
app.py CHANGED
@@ -76,28 +76,15 @@ def render_pc(pc):
 
 
 try:
-    tab_cls, tab_cap = st.tabs(["Classification", "Point Cloud Captioning"])
-
-    with tab_cls:
-        if st.button("Run Classification on LVIS Categories"):
-            pc = load_data()
-            col2 = render_pc(pc)
-            prog.progress(0.5, "Running Classification")
-            pred = openshape.pred_lvis_sims(model_g14, pc)
-            with col2:
-                for i, (cat, sim) in zip(range(5), pred.items()):
-                    st.text(cat)
-                    st.caption("Similarity %.4f" % sim)
-            prog.progress(1.0, "Idle")
-
-    with tab_cap:
-        cond_scale = st.slider('Conditioning Scale', 0.0, 4.0, 2.0)
-        if st.button("Generate a Caption"):
-            pc = load_data()
-            col2 = render_pc(pc)
-            prog.progress(0.5, "Running Generation")
-            cap = openshape.pc_caption(model_b32, pc, cond_scale)
-            st.text(cap)
-            prog.progress(1.0, "Idle")
+    if st.button("Run Classification on LVIS Categories"):
+        pc = load_data()
+        col2 = render_pc(pc)
+        prog.progress(0.5, "Running Classification")
+        pred = openshape.pred_lvis_sims(model_g14, pc)
+        with col2:
+            for i, (cat, sim) in zip(range(5), pred.items()):
+                st.text(cat)
+                st.caption("Similarity %.4f" % sim)
+        prog.progress(1.0, "Idle")
 except Exception as exc:
     st.error(repr(exc))
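For context, the loop kept in app.py reads the top five LVIS categories from the {category: similarity} mapping that openshape.pred_lvis_sims is assumed to return ordered by decreasing similarity. A minimal sketch of that consumption pattern, using a stand-in prediction dict instead of a real model and no Streamlit:

# Stand-in for the ordered {category: similarity} dict assumed to come
# from openshape.pred_lvis_sims; values here are illustrative only.
pred = {"chair": 0.31, "armchair": 0.28, "sofa": 0.22,
        "stool": 0.19, "bench": 0.17, "table": 0.12}

# Same idiom as the diff above: zipping against range(5) keeps only the first five entries.
for i, (cat, sim) in zip(range(5), pred.items()):
    print(cat, "Similarity %.4f" % sim)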
openshape/__init__.py CHANGED
@@ -49,5 +49,4 @@ def load_pc_encoder(name):
 
 # only import the functions in demo!
 # from .sd_pc2img import pc_to_image
-from .caption import pc_caption
 from .classification import pred_lvis_sims
openshape/caption.py DELETED
@@ -1,163 +0,0 @@
-from torch import nn
-import numpy as np
-import torch
-from typing import Tuple, List, Union, Optional
-from transformers import GPT2Tokenizer, GPT2LMHeadModel
-from huggingface_hub import hf_hub_download
-
-
-N = type(None)
-V = np.array
-ARRAY = np.ndarray
-ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
-VS = Union[Tuple[V, ...], List[V]]
-VN = Union[V, N]
-VNS = Union[VS, N]
-T = torch.Tensor
-TS = Union[Tuple[T, ...], List[T]]
-TN = Optional[T]
-TNS = Union[Tuple[TN, ...], List[TN]]
-TSN = Optional[TS]
-TA = Union[T, ARRAY]
-
-
-D = torch.device
-
-
-class MLP(nn.Module):
-
-    def forward(self, x: T) -> T:
-        return self.model(x)
-
-    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
-        super(MLP, self).__init__()
-        layers = []
-        for i in range(len(sizes) - 1):
-            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
-            if i < len(sizes) - 2:
-                layers.append(act())
-        self.model = nn.Sequential(*layers)
-
-
-class ClipCaptionModel(nn.Module):
-
-    #@functools.lru_cache #FIXME
-    def get_dummy_token(self, batch_size: int, device: D) -> T:
-        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
-
-    def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
-        embedding_text = self.gpt.transformer.wte(tokens)
-        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
-        #print(embedding_text.size()) #torch.Size([5, 67, 768])
-        #print(prefix_projections.size()) #torch.Size([5, 1, 768])
-        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
-        if labels is not None:
-            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
-            labels = torch.cat((dummy_token, tokens), dim=1)
-        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
-        return out
-
-    def __init__(self, prefix_length: int, prefix_size: int = 512):
-        super(ClipCaptionModel, self).__init__()
-        self.prefix_length = prefix_length
-        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
-        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
-        if prefix_length > 10:  # not enough memory
-            self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
-        else:
-            self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))
-
-
-class ClipCaptionPrefix(ClipCaptionModel):
-
-    def parameters(self, recurse: bool = True):
-        return self.clip_project.parameters()
-
-    def train(self, mode: bool = True):
-        super(ClipCaptionPrefix, self).train(mode)
-        self.gpt.eval()
-        return self
-
-
-def generate2(
-    model,
-    tokenizer,
-    tokens=None,
-    prompt=None,
-    embed=None,
-    entry_count=1,
-    entry_length=67,  # maximum number of words
-    top_p=0.8,
-    temperature=1.,
-    stop_token: str = '.',
-):
-    model.eval()
-    generated_num = 0
-    generated_list = []
-    stop_token_index = tokenizer.encode(stop_token)[0]
-    filter_value = -float("Inf")
-    device = next(model.parameters()).device
-    score_col = []
-    with torch.no_grad():
-
-        for entry_idx in range(entry_count):
-            if embed is not None:
-                generated = embed
-            else:
-                if tokens is None:
-                    tokens = torch.tensor(tokenizer.encode(prompt))
-                    tokens = tokens.unsqueeze(0).to(device)
-
-                generated = model.gpt.transformer.wte(tokens)
-
-            for i in range(entry_length):
-
-                outputs = model.gpt(inputs_embeds=generated)
-                logits = outputs.logits
-                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
-                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
-                sorted_indices_to_remove = cumulative_probs > top_p
-                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
-                    ..., :-1
-                ].clone()
-                sorted_indices_to_remove[..., 0] = 0
-
-                indices_to_remove = sorted_indices[sorted_indices_to_remove]
-                logits[:, indices_to_remove] = filter_value
-                next_token = torch.argmax(torch.softmax(logits, dim=-1), -1).reshape(1, 1)
-                score = torch.softmax(logits, dim=-1).reshape(-1)[next_token.item()].item()
-                score_col.append(score)
-                next_token_embed = model.gpt.transformer.wte(next_token)
-                if tokens is None:
-                    tokens = next_token
-                else:
-                    tokens = torch.cat((tokens, next_token), dim=1)
-                generated = torch.cat((generated, next_token_embed), dim=1)
-                if stop_token_index == next_token.item():
-                    break
-
-            output_list = list(tokens.squeeze(0).cpu().numpy())
-            output_text = tokenizer.decode(output_list)
-            generated_list.append(output_text)
-    return generated_list[0]
-
-
-@torch.no_grad()
-def pc_caption(pc_encoder: torch.nn.Module, pc, cond_scale):
-    ref_dev = next(pc_encoder.parameters()).device
-    prefix = pc_encoder(torch.tensor(pc.T[None], device=ref_dev))
-    prefix = prefix.float() * cond_scale
-    prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
-    text = generate2(model, tokenizer, embed=prefix_embed)
-    return text
-
-
-tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-prefix_length = 10
-model = ClipCaptionModel(prefix_length)
-# print(model.gpt_embedding_size)
-model.load_state_dict(torch.load(hf_hub_download('OpenShape/clipcap-cc', 'conceptual_weights.pt', token=True), map_location='cpu'))
-model.eval()
-if torch.cuda.is_available():
-    model = model.cuda()
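The deleted generate2 decodes greedily, but it first masks the tail of the next-token distribution with top-p (nucleus) filtering. Below is a self-contained sketch of just that filtering step, operating on a toy logits vector instead of GPT-2 output; the function name and example values are illustrative, not part of the deleted module.

import torch

def next_token_top_p(logits: torch.Tensor, temperature: float = 1.0, top_p: float = 0.8) -> int:
    # Mirrors the filtering in the deleted generate2(): scale by temperature,
    # sort the logits, mask tokens whose cumulative probability exceeds top_p,
    # then pick the argmax of what remains.
    logits = logits / (temperature if temperature > 0 else 1.0)
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cumulative_probs > top_p
    remove[1:] = remove[:-1].clone()  # always keep the first token past the threshold
    remove[0] = False
    logits[sorted_indices[remove]] = -float("Inf")
    return int(torch.argmax(torch.softmax(logits, dim=-1)))

# Toy vocabulary of five tokens; prints the index of the selected token (0 here).
print(next_token_top_p(torch.tensor([2.0, 1.0, 0.5, 0.1, -1.0])))

As in the deleted code, the final pick is an argmax, so the mask would only change the outcome if the argmax were replaced with sampling from the filtered distribution.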