eliphatfs committed
Commit ff2e0a9
1 Parent(s): 92ce27f

Use support library.

app.py CHANGED
@@ -1,14 +1,12 @@
  import streamlit as st
- from huggingface_hub import HfFolder
+ from huggingface_hub import HfFolder, snapshot_download
  HfFolder().save_token(st.secrets['etoken'])
+ snapshot_download("OpenShape/openshape-demo-support", local_dir='.')
  
  
  import numpy
- import trimesh
- import objaverse
  import openshape
- import misc_utils
- import plotly.graph_objects as go
+ from openshape.demo import misc_utils, classification
  
  
  @st.cache_resource
@@ -17,70 +15,21 @@ def load_openshape(name):
  
  
  f32 = numpy.float32
- model_b32 = openshape.load_pc_encoder('openshape-pointbert-vitb32-rgb')
- model_l14 = openshape.load_pc_encoder('openshape-pointbert-vitl14-rgb')
+ # clip_model, clip_prep = load_openclip()
  model_g14 = openshape.load_pc_encoder('openshape-pointbert-vitg14-rgb')
  
  
  st.title("OpenShape Demo")
- objaid = st.text_input("Enter an Objaverse ID")
- model = st.file_uploader("Or upload a model (.glb/.obj/.ply)")
- npy = st.file_uploader("Or upload a point cloud numpy array (.npy of Nx3 XYZ or Nx6 XYZRGB)")
- swap_yz_axes = st.checkbox("Swap Y/Z axes of input (Y is up for OpenShape)")
+ load_data = misc_utils.input_3d_shape()
  prog = st.progress(0.0, "Idle")
  
  
- def load_data():
-     # load the model
-     prog.progress(0.05, "Preparing Point Cloud")
-     if npy is not None:
-         pc: numpy.ndarray = numpy.load(npy)
-     elif model is not None:
-         pc = misc_utils.trimesh_to_pc(trimesh.load(model, model.name.split(".")[-1]))
-     elif objaid:
-         prog.progress(0.1, "Downloading Objaverse Object")
-         objamodel = objaverse.load_objects([objaid])[objaid]
-         prog.progress(0.2, "Preparing Point Cloud")
-         pc = misc_utils.trimesh_to_pc(trimesh.load(objamodel))
-     else:
-         raise ValueError("You have to supply 3D input!")
-     prog.progress(0.25, "Preprocessing Point Cloud")
-     assert pc.ndim == 2, "invalid pc shape: ndim = %d != 2" % pc.ndim
-     assert pc.shape[1] in [3, 6], "invalid pc shape: should have 3/6 channels, got %d" % pc.shape[1]
-     if swap_yz_axes:
-         pc[:, [1, 2]] = pc[:, [2, 1]]
-     pc[:, :3] = pc[:, :3] - numpy.mean(pc[:, :3], axis=0)
-     pc[:, :3] = pc[:, :3] / numpy.linalg.norm(pc[:, :3], axis=-1).max()
-     if pc.shape[1] == 3:
-         pc = numpy.concatenate([pc, numpy.ones_like(pc)], axis=-1)
-     prog.progress(0.3, "Preprocessed Point Cloud")
-     return pc.astype(f32)
- 
- 
- def render_pc(pc):
-     rand = numpy.random.permutation(len(pc))[:2048]
-     pc = pc[rand]
-     rgb = (pc[:, 3:] * 255).astype(numpy.uint8)
-     g = go.Scatter3d(
-         x=pc[:, 0], y=pc[:, 1], z=pc[:, 2],
-         mode='markers',
-         marker=dict(size=2, color=[f'rgb({rgb[i, 0]}, {rgb[i, 1]}, {rgb[i, 2]})' for i in range(len(pc))]),
-     )
-     fig = go.Figure(data=[g])
-     fig.update_layout(scene_camera=dict(up=dict(x=0, y=1, z=0)))
-     col1, col2 = st.columns(2)
-     with col1:
-         st.plotly_chart(fig, use_container_width=True)
-         # st.caption("Point Cloud Preview")
-     return col2
- 
- 
  try:
      if st.button("Run Classification on LVIS Categories"):
-         pc = load_data()
-         col2 = render_pc(pc)
+         pc = load_data(prog)
+         col2 = misc_utils.render_pc(pc)
          prog.progress(0.5, "Running Classification")
-         pred = openshape.pred_lvis_sims(model_g14, pc)
+         pred = classification.pred_lvis_sims(model_g14, pc)
          with col2:
              for i, (cat, sim) in zip(range(5), pred.items()):
                  st.text(cat)
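A minimal sketch of the pattern the new app.py adopts, for clarity. It assumes, as the diff suggests, that the OpenShape/openshape-demo-support repository ships an importable openshape package with a demo submodule; every name used below appears in the diff above, nothing else is implied.

from huggingface_hub import snapshot_download

# Fetch the support library into the working directory; its files land next to
# app.py, so the subsequent `import openshape` resolves to the downloaded package
# instead of the local sources deleted below.
snapshot_download("OpenShape/openshape-demo-support", local_dir=".")

import openshape
from openshape.demo import misc_utils, classification

model_g14 = openshape.load_pc_encoder("openshape-pointbert-vitg14-rgb")  # point cloud encoder
load_data = misc_utils.input_3d_shape()  # returns a callable, later invoked as load_data(prog)
# classification.pred_lvis_sims(model_g14, pc) now replaces the old openshape.pred_lvis_sims.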
misc_utils.py DELETED
@@ -1,90 +0,0 @@
1
- import numpy
2
- import trimesh
3
- import trimesh.sample
4
- import trimesh.visual
5
- import trimesh.proximity
6
- import streamlit as st
7
- import matplotlib.pyplot as plotlib
8
-
9
-
10
- def get_bytes(x: str):
11
- import io, requests
12
- return io.BytesIO(requests.get(x).content)
13
-
14
-
15
- def get_image(x: str):
16
- try:
17
- return plotlib.imread(get_bytes(x), 'auto')
18
- except Exception:
19
- raise ValueError("Invalid image", x)
20
-
21
-
22
- def model_to_pc(mesh: trimesh.Trimesh, n_sample_points=10000):
23
- f32 = numpy.float32
24
- rad = numpy.sqrt(mesh.area / (3 * n_sample_points))
25
- for _ in range(24):
26
- pcd, face_idx = trimesh.sample.sample_surface_even(mesh, n_sample_points, rad)
27
- rad *= 0.85
28
- if len(pcd) == n_sample_points:
29
- break
30
- else:
31
- raise ValueError("Bad geometry, cannot finish sampling.", mesh.area)
32
- if isinstance(mesh.visual, trimesh.visual.ColorVisuals):
33
- rgba = mesh.visual.face_colors[face_idx]
34
- elif isinstance(mesh.visual, trimesh.visual.TextureVisuals):
35
- bc = trimesh.proximity.points_to_barycentric(mesh.triangles[face_idx], pcd)
36
- if mesh.visual.uv is None or len(mesh.visual.uv) < mesh.faces[face_idx].max():
37
- uv = numpy.zeros([len(bc), 2])
38
- st.warning("Invalid UV, filling with zeroes")
39
- else:
40
- uv = numpy.einsum('ntc,nt->nc', mesh.visual.uv[mesh.faces[face_idx]], bc)
41
- material = mesh.visual.material
42
- if hasattr(material, 'materials'):
43
- if len(material.materials) == 0:
44
- rgba = numpy.ones_like(pcd) * 0.8
45
- texture = None
46
- st.warning("Empty MultiMaterial found, falling back to light grey")
47
- else:
48
- material = material.materials[0]
49
- if hasattr(material, 'image'):
50
- texture = material.image
51
- if texture is None:
52
- rgba = numpy.zeros([len(uv), len(material.main_color)]) + material.main_color
53
- elif hasattr(material, 'baseColorTexture'):
54
- texture = material.baseColorTexture
55
- if texture is None:
56
- rgba = numpy.zeros([len(uv), len(material.main_color)]) + material.main_color
57
- else:
58
- texture = None
59
- rgba = numpy.ones_like(pcd) * 0.8
60
- st.warning("Unknown material, falling back to light grey")
61
- if texture is not None:
62
- rgba = trimesh.visual.uv_to_interpolated_color(uv, texture)
63
- if rgba.max() > 1:
64
- if rgba.max() > 255:
65
- rgba = rgba.astype(f32) / rgba.max()
66
- else:
67
- rgba = rgba.astype(f32) / 255.0
68
- return numpy.concatenate([numpy.array(pcd, f32), numpy.array(rgba, f32)[:, :3]], axis=-1)
69
-
70
-
71
- def trimesh_to_pc(scene_or_mesh):
72
- if isinstance(scene_or_mesh, trimesh.Scene):
73
- meshes = []
74
- for node_name in scene_or_mesh.graph.nodes_geometry:
75
- # which geometry does this node refer to
76
- transform, geometry_name = scene_or_mesh.graph[node_name]
77
-
78
- # get the actual potential mesh instance
79
- geometry = scene_or_mesh.geometry[geometry_name].copy()
80
- if not hasattr(geometry, 'triangles'):
81
- continue
82
- geometry: trimesh.Trimesh
83
- geometry = geometry.apply_transform(transform)
84
- meshes.append(model_to_pc(geometry, 10000 // len(scene_or_mesh.geometry)))
85
- if not len(meshes):
86
- raise ValueError("Unsupported mesh object: no triangles found")
87
- return numpy.concatenate(meshes)
88
- else:
89
- assert isinstance(scene_or_mesh, trimesh.Trimesh)
90
- return model_to_pc(scene_or_mesh, 10000)
 
openshape/__init__.py DELETED
@@ -1,52 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from huggingface_hub import hf_hub_download
4
- from .ppat_rgb import Projected, PointPatchTransformer
5
-
6
-
7
- def module(state_dict: dict, name):
8
- return {'.'.join(k.split('.')[1:]): v for k, v in state_dict.items() if k.startswith(name + '.')}
9
-
10
-
11
- def G14(s):
12
- model = Projected(
13
- PointPatchTransformer(512, 12, 8, 512*3, 256, 384, 0.2, 64, 6),
14
- nn.Linear(512, 1280)
15
- )
16
- model.load_state_dict(module(s['state_dict'], 'module'))
17
- return model
18
-
19
-
20
- def L14(s):
21
- model = Projected(
22
- PointPatchTransformer(512, 12, 8, 1024, 128, 64, 0.4, 256, 6),
23
- nn.Linear(512, 768)
24
- )
25
- model.load_state_dict(module(s, 'pc_encoder'))
26
- return model
27
-
28
-
29
- def B32(s):
30
- model = PointPatchTransformer(512, 12, 8, 1024, 128, 64, 0.4, 256, 6)
31
- model.load_state_dict(module(s, 'pc_encoder'))
32
- return model
33
-
34
-
35
- model_list = {
36
- "openshape-pointbert-vitb32-rgb": B32,
37
- "openshape-pointbert-vitl14-rgb": L14,
38
- "openshape-pointbert-vitg14-rgb": G14,
39
- }
40
-
41
-
42
- def load_pc_encoder(name):
43
- s = torch.load(hf_hub_download("OpenShape/" + name, "model.pt", token=True), map_location='cpu')
44
- model = model_list[name](s).eval()
45
- if torch.cuda.is_available():
46
- model.cuda()
47
- return model
48
-
49
-
50
- # only import the functions in demo!
51
- # from .sd_pc2img import pc_to_image
52
- from .classification import pred_lvis_sims
 
openshape/classification.py DELETED
@@ -1,13 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from collections import OrderedDict
4
- from . import lvis
5
-
6
-
7
- @torch.no_grad()
8
- def pred_lvis_sims(pc_encoder: torch.nn.Module, pc):
9
- ref_dev = next(pc_encoder.parameters()).device
10
- enc = pc_encoder(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
11
- sim = torch.matmul(F.normalize(lvis.feats, dim=-1), F.normalize(enc, dim=-1).squeeze())
12
- argsort = torch.argsort(sim, descending=True)
13
- return OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
 
openshape/lvis.py DELETED
@@ -1,1162 +0,0 @@
1
- import os
2
- import torch
3
-
4
-
5
- feats = torch.load(os.path.join(os.path.dirname(__file__), 'lvis_cats.pt'))
6
- categories = [
7
- 'Band_Aid',
8
- 'Bible',
9
- 'CD_player',
10
- 'Christmas_tree',
11
- 'Dixie_cup',
12
- 'Ferris_wheel',
13
- 'Lego',
14
- 'Rollerblade',
15
- 'Sharpie',
16
- 'Tabasco_sauce',
17
- 'aerosol_can',
18
- 'air_conditioner',
19
- 'airplane',
20
- 'alarm_clock',
21
- 'alcohol',
22
- 'alligator',
23
- 'almond',
24
- 'ambulance',
25
- 'amplifier',
26
- 'anklet',
27
- 'antenna',
28
- 'apple',
29
- 'apricot',
30
- 'apron',
31
- 'aquarium',
32
- 'arctic_(type_of_shoe)',
33
- 'armband',
34
- 'armchair',
35
- 'armoire',
36
- 'armor',
37
- 'army_tank',
38
- 'artichoke',
39
- 'ashtray',
40
- 'asparagus',
41
- 'atomizer',
42
- 'automatic_washer',
43
- 'avocado',
44
- 'award',
45
- 'awning',
46
- 'ax',
47
- 'baboon',
48
- 'baby_buggy',
49
- 'backpack',
50
- 'bagel',
51
- 'baguet',
52
- 'bait',
53
- 'ball',
54
- 'ballet_skirt',
55
- 'balloon',
56
- 'bamboo',
57
- 'banana',
58
- 'bandage',
59
- 'bandanna',
60
- 'banjo',
61
- 'banner',
62
- 'barbell',
63
- 'barge',
64
- 'barrel',
65
- 'barrow',
66
- 'baseball',
67
- 'baseball_bat',
68
- 'baseball_cap',
69
- 'baseball_glove',
70
- 'basket',
71
- 'basketball',
72
- 'basketball_backboard',
73
- 'bass_horn',
74
- 'bat_(animal)',
75
- 'bath_mat',
76
- 'bath_towel',
77
- 'bathrobe',
78
- 'bathtub',
79
- 'battery',
80
- 'beachball',
81
- 'bead',
82
- 'beanbag',
83
- 'beanie',
84
- 'bear',
85
- 'bed',
86
- 'bedpan',
87
- 'bedspread',
88
- 'beef_(food)',
89
- 'beeper',
90
- 'beer_bottle',
91
- 'beer_can',
92
- 'beetle',
93
- 'bell',
94
- 'bell_pepper',
95
- 'belt',
96
- 'belt_buckle',
97
- 'bench',
98
- 'beret',
99
- 'bicycle',
100
- 'billboard',
101
- 'binder',
102
- 'binoculars',
103
- 'bird',
104
- 'birdbath',
105
- 'birdcage',
106
- 'birdfeeder',
107
- 'birdhouse',
108
- 'birthday_cake',
109
- 'birthday_card',
110
- 'blackberry',
111
- 'blackboard',
112
- 'blanket',
113
- 'blazer',
114
- 'blender',
115
- 'blimp',
116
- 'blouse',
117
- 'blueberry',
118
- 'boat',
119
- 'bob',
120
- 'bobbin',
121
- 'boiled_egg',
122
- 'bolo_tie',
123
- 'bolt',
124
- 'bonnet',
125
- 'book',
126
- 'bookcase',
127
- 'booklet',
128
- 'bookmark',
129
- 'boom_microphone',
130
- 'boot',
131
- 'bottle',
132
- 'bottle_cap',
133
- 'bottle_opener',
134
- 'bouquet',
135
- 'bow-tie',
136
- 'bow_(decorative_ribbons)',
137
- 'bow_(weapon)',
138
- 'bowl',
139
- 'bowler_hat',
140
- 'bowling_ball',
141
- 'box',
142
- 'boxing_glove',
143
- 'bracelet',
144
- 'brass_plaque',
145
- 'brassiere',
146
- 'bread',
147
- 'bread-bin',
148
- 'breechcloth',
149
- 'bridal_gown',
150
- 'briefcase',
151
- 'broach',
152
- 'broccoli',
153
- 'broom',
154
- 'brownie',
155
- 'brussels_sprouts',
156
- 'bubble_gum',
157
- 'bucket',
158
- 'bulldog',
159
- 'bulldozer',
160
- 'bullet_train',
161
- 'bulletin_board',
162
- 'bulletproof_vest',
163
- 'bullhorn',
164
- 'bun',
165
- 'bunk_bed',
166
- 'buoy',
167
- 'burrito',
168
- 'bus_(vehicle)',
169
- 'business_card',
170
- 'butter',
171
- 'butterfly',
172
- 'button',
173
- 'cab_(taxi)',
174
- 'cabana',
175
- 'cabin_car',
176
- 'cabinet',
177
- 'cake',
178
- 'calculator',
179
- 'calendar',
180
- 'calf',
181
- 'camcorder',
182
- 'camel',
183
- 'camera',
184
- 'camera_lens',
185
- 'camper_(vehicle)',
186
- 'can',
187
- 'can_opener',
188
- 'candle',
189
- 'candle_holder',
190
- 'candy_bar',
191
- 'candy_cane',
192
- 'canister',
193
- 'canoe',
194
- 'cantaloup',
195
- 'canteen',
196
- 'cap_(headwear)',
197
- 'cape',
198
- 'cappuccino',
199
- 'car_(automobile)',
200
- 'car_battery',
201
- 'card',
202
- 'cardigan',
203
- 'cargo_ship',
204
- 'carnation',
205
- 'carrot',
206
- 'cart',
207
- 'carton',
208
- 'cash_register',
209
- 'casserole',
210
- 'cassette',
211
- 'cast',
212
- 'cat',
213
- 'cauliflower',
214
- 'cayenne_(spice)',
215
- 'celery',
216
- 'cellular_telephone',
217
- 'chair',
218
- 'chaise_longue',
219
- 'chalice',
220
- 'chandelier',
221
- 'checkbook',
222
- 'checkerboard',
223
- 'cherry',
224
- 'chessboard',
225
- 'chicken_(animal)',
226
- 'chili_(vegetable)',
227
- 'chime',
228
- 'chinaware',
229
- 'chocolate_bar',
230
- 'chocolate_cake',
231
- 'chocolate_milk',
232
- 'chocolate_mousse',
233
- 'choker',
234
- 'chopping_board',
235
- 'chopstick',
236
- 'cider',
237
- 'cigar_box',
238
- 'cigarette',
239
- 'cigarette_case',
240
- 'cincture',
241
- 'cistern',
242
- 'clarinet',
243
- 'clasp',
244
- 'cleansing_agent',
245
- 'cleat_(for_securing_rope)',
246
- 'clementine',
247
- 'clip',
248
- 'clipboard',
249
- 'clippers_(for_plants)',
250
- 'cloak',
251
- 'clock',
252
- 'clock_tower',
253
- 'clothes_hamper',
254
- 'clothespin',
255
- 'clutch_bag',
256
- 'coaster',
257
- 'coat',
258
- 'coat_hanger',
259
- 'coatrack',
260
- 'cock',
261
- 'cockroach',
262
- 'cocoa_(beverage)',
263
- 'coconut',
264
- 'coffee_maker',
265
- 'coffee_table',
266
- 'coffeepot',
267
- 'coil',
268
- 'coin',
269
- 'colander',
270
- 'coloring_material',
271
- 'combination_lock',
272
- 'comic_book',
273
- 'compass',
274
- 'computer_keyboard',
275
- 'condiment',
276
- 'cone',
277
- 'control',
278
- 'convertible_(automobile)',
279
- 'cooker',
280
- 'cookie',
281
- 'cooking_utensil',
282
- 'cooler_(for_food)',
283
- 'cork_(bottle_plug)',
284
- 'corkboard',
285
- 'corkscrew',
286
- 'cornbread',
287
- 'cornet',
288
- 'cornice',
289
- 'cornmeal',
290
- 'corset',
291
- 'costume',
292
- 'cougar',
293
- 'cover',
294
- 'coverall',
295
- 'cow',
296
- 'cowbell',
297
- 'cowboy_hat',
298
- 'crab_(animal)',
299
- 'crabmeat',
300
- 'cracker',
301
- 'crape',
302
- 'crate',
303
- 'crawfish',
304
- 'crayon',
305
- 'cream_pitcher',
306
- 'crescent_roll',
307
- 'crib',
308
- 'crisp_(potato_chip)',
309
- 'crossbar',
310
- 'crouton',
311
- 'crow',
312
- 'crowbar',
313
- 'crown',
314
- 'crucifix',
315
- 'cruise_ship',
316
- 'crutch',
317
- 'cub_(animal)',
318
- 'cube',
319
- 'cucumber',
320
- 'cufflink',
321
- 'cup',
322
- 'cupboard',
323
- 'cupcake',
324
- 'curtain',
325
- 'cushion',
326
- 'cylinder',
327
- 'cymbal',
328
- 'dagger',
329
- 'dalmatian',
330
- 'dartboard',
331
- 'date_(fruit)',
332
- 'deadbolt',
333
- 'deck_chair',
334
- 'deer',
335
- 'desk',
336
- 'detergent',
337
- 'diaper',
338
- 'diary',
339
- 'die',
340
- 'dinghy',
341
- 'dining_table',
342
- 'dirt_bike',
343
- 'dish',
344
- 'dish_antenna',
345
- 'dishrag',
346
- 'dishtowel',
347
- 'dishwasher',
348
- 'dishwasher_detergent',
349
- 'dispenser',
350
- 'dog',
351
- 'dog_collar',
352
- 'doll',
353
- 'dollar',
354
- 'dollhouse',
355
- 'dolphin',
356
- 'domestic_ass',
357
- 'doorknob',
358
- 'doormat',
359
- 'doughnut',
360
- 'dove',
361
- 'dragonfly',
362
- 'drawer',
363
- 'dress',
364
- 'dress_hat',
365
- 'dress_suit',
366
- 'dresser',
367
- 'drill',
368
- 'drone',
369
- 'drum_(musical_instrument)',
370
- 'drumstick',
371
- 'duck',
372
- 'duckling',
373
- 'duct_tape',
374
- 'duffel_bag',
375
- 'dumbbell',
376
- 'dumpster',
377
- 'dustpan',
378
- 'eagle',
379
- 'earphone',
380
- 'earplug',
381
- 'earring',
382
- 'easel',
383
- 'eclair',
384
- 'edible_corn',
385
- 'eel',
386
- 'egg',
387
- 'egg_roll',
388
- 'egg_yolk',
389
- 'eggbeater',
390
- 'eggplant',
391
- 'elephant',
392
- 'elevator_car',
393
- 'elk',
394
- 'envelope',
395
- 'eraser',
396
- 'escargot',
397
- 'eyepatch',
398
- 'falcon',
399
- 'fan',
400
- 'faucet',
401
- 'fedora',
402
- 'ferret',
403
- 'ferry',
404
- 'fig_(fruit)',
405
- 'fighter_jet',
406
- 'figurine',
407
- 'file_(tool)',
408
- 'file_cabinet',
409
- 'fire_alarm',
410
- 'fire_engine',
411
- 'fire_extinguisher',
412
- 'fire_hose',
413
- 'fireplace',
414
- 'fireplug',
415
- 'first-aid_kit',
416
- 'fish',
417
- 'fish_(food)',
418
- 'fishbowl',
419
- 'fishing_rod',
420
- 'flag',
421
- 'flagpole',
422
- 'flamingo',
423
- 'flannel',
424
- 'flap',
425
- 'flash',
426
- 'flashlight',
427
- 'fleece',
428
- 'flip-flop_(sandal)',
429
- 'flipper_(footwear)',
430
- 'flower_arrangement',
431
- 'flowerpot',
432
- 'flute_glass',
433
- 'foal',
434
- 'folding_chair',
435
- 'food_processor',
436
- 'football_(American)',
437
- 'football_helmet',
438
- 'footstool',
439
- 'fork',
440
- 'forklift',
441
- 'freight_car',
442
- 'freshener',
443
- 'frisbee',
444
- 'frog',
445
- 'fruit_juice',
446
- 'frying_pan',
447
- 'fume_hood',
448
- 'funnel',
449
- 'futon',
450
- 'gameboard',
451
- 'garbage',
452
- 'garbage_truck',
453
- 'garden_hose',
454
- 'gargle',
455
- 'gargoyle',
456
- 'garlic',
457
- 'gasmask',
458
- 'gazelle',
459
- 'gelatin',
460
- 'gemstone',
461
- 'generator',
462
- 'giant_panda',
463
- 'gift_wrap',
464
- 'ginger',
465
- 'giraffe',
466
- 'glass_(drink_container)',
467
- 'globe',
468
- 'glove',
469
- 'goat',
470
- 'goggles',
471
- 'goldfish',
472
- 'golf_club',
473
- 'golfcart',
474
- 'gondola_(boat)',
475
- 'goose',
476
- 'gorilla',
477
- 'gourd',
478
- 'grape',
479
- 'grater',
480
- 'gravestone',
481
- 'gravy_boat',
482
- 'green_bean',
483
- 'green_onion',
484
- 'grill',
485
- 'grits',
486
- 'grizzly',
487
- 'grocery_bag',
488
- 'guitar',
489
- 'gull',
490
- 'gun',
491
- 'hair_dryer',
492
- 'hairbrush',
493
- 'hairnet',
494
- 'halter_top',
495
- 'ham',
496
- 'hamburger',
497
- 'hammer',
498
- 'hammock',
499
- 'hamper',
500
- 'hamster',
501
- 'hand_glass',
502
- 'hand_towel',
503
- 'handbag',
504
- 'handcart',
505
- 'handcuff',
506
- 'handkerchief',
507
- 'handle',
508
- 'handsaw',
509
- 'hardback_book',
510
- 'harmonium',
511
- 'hat',
512
- 'hatbox',
513
- 'headband',
514
- 'headboard',
515
- 'headlight',
516
- 'headscarf',
517
- 'headset',
518
- 'headstall_(for_horses)',
519
- 'heart',
520
- 'heater',
521
- 'helicopter',
522
- 'helmet',
523
- 'heron',
524
- 'highchair',
525
- 'hinge',
526
- 'hippopotamus',
527
- 'hockey_stick',
528
- 'hog',
529
- 'honey',
530
- 'hook',
531
- 'hookah',
532
- 'horned_cow',
533
- 'hornet',
534
- 'horse',
535
- 'horse_buggy',
536
- 'horse_carriage',
537
- 'hose',
538
- 'hot-air_balloon',
539
- 'hot_sauce',
540
- 'hotplate',
541
- 'hourglass',
542
- 'houseboat',
543
- 'hummingbird',
544
- 'iPod',
545
- 'ice_maker',
546
- 'ice_pack',
547
- 'ice_skate',
548
- 'icecream',
549
- 'identity_card',
550
- 'igniter',
551
- 'inhaler',
552
- 'inkpad',
553
- 'iron_(for_clothing)',
554
- 'ironing_board',
555
- 'jacket',
556
- 'jam',
557
- 'jar',
558
- 'jean',
559
- 'jeep',
560
- 'jersey',
561
- 'jet_plane',
562
- 'jewel',
563
- 'jewelry',
564
- 'joystick',
565
- 'jumpsuit',
566
- 'kayak',
567
- 'keg',
568
- 'kennel',
569
- 'kettle',
570
- 'key',
571
- 'keycard',
572
- 'kilt',
573
- 'kimono',
574
- 'kitchen_sink',
575
- 'kitchen_table',
576
- 'kite',
577
- 'kitten',
578
- 'kiwi_fruit',
579
- 'knee_pad',
580
- 'knife',
581
- 'knitting_needle',
582
- 'knob',
583
- 'knocker_(on_a_door)',
584
- 'koala',
585
- 'lab_coat',
586
- 'ladder',
587
- 'ladle',
588
- 'ladybug',
589
- 'lamb-chop',
590
- 'lamb_(animal)',
591
- 'lamp',
592
- 'lamppost',
593
- 'lampshade',
594
- 'lantern',
595
- 'laptop_computer',
596
- 'lasagna',
597
- 'latch',
598
- 'lawn_mower',
599
- 'leather',
600
- 'legging_(clothing)',
601
- 'legume',
602
- 'lemon',
603
- 'lemonade',
604
- 'lettuce',
605
- 'license_plate',
606
- 'life_buoy',
607
- 'life_jacket',
608
- 'lightbulb',
609
- 'lightning_rod',
610
- 'lime',
611
- 'limousine',
612
- 'lion',
613
- 'lip_balm',
614
- 'liquor',
615
- 'lizard',
616
- 'locker',
617
- 'log',
618
- 'lollipop',
619
- 'loveseat',
620
- 'machine_gun',
621
- 'magazine',
622
- 'magnet',
623
- 'mail_slot',
624
- 'mailbox_(at_home)',
625
- 'mallard',
626
- 'mallet',
627
- 'mammoth',
628
- 'manatee',
629
- 'mandarin_orange',
630
- 'manger',
631
- 'manhole',
632
- 'map',
633
- 'marker',
634
- 'martini',
635
- 'mascot',
636
- 'mashed_potato',
637
- 'mask',
638
- 'mast',
639
- 'mat_(gym_equipment)',
640
- 'matchbox',
641
- 'mattress',
642
- 'measuring_cup',
643
- 'measuring_stick',
644
- 'meatball',
645
- 'medicine',
646
- 'melon',
647
- 'microphone',
648
- 'microscope',
649
- 'microwave_oven',
650
- 'milestone',
651
- 'milk',
652
- 'milk_can',
653
- 'milkshake',
654
- 'minivan',
655
- 'mint_candy',
656
- 'mirror',
657
- 'mitten',
658
- 'mixer_(kitchen_tool)',
659
- 'money',
660
- 'monitor_(computer_equipment) computer_monitor',
661
- 'monkey',
662
- 'mop',
663
- 'motor',
664
- 'motor_scooter',
665
- 'motor_vehicle',
666
- 'motorcycle',
667
- 'mound_(baseball)',
668
- 'mouse_(computer_equipment)',
669
- 'mousepad',
670
- 'muffin',
671
- 'mug',
672
- 'mushroom',
673
- 'music_stool',
674
- 'musical_instrument',
675
- 'nailfile',
676
- 'napkin',
677
- 'neckerchief',
678
- 'necklace',
679
- 'necktie',
680
- 'needle',
681
- 'nest',
682
- 'newspaper',
683
- 'newsstand',
684
- 'nightshirt',
685
- 'notebook',
686
- 'notepad',
687
- 'nut',
688
- 'nutcracker',
689
- 'oar',
690
- 'octopus_(animal)',
691
- 'octopus_(food)',
692
- 'oil_lamp',
693
- 'olive_oil',
694
- 'omelet',
695
- 'onion',
696
- 'orange_(fruit)',
697
- 'orange_juice',
698
- 'ostrich',
699
- 'ottoman',
700
- 'oven',
701
- 'overalls_(clothing)',
702
- 'owl',
703
- 'pacifier',
704
- 'packet',
705
- 'paddle',
706
- 'padlock',
707
- 'paintbrush',
708
- 'painting',
709
- 'pajamas',
710
- 'palette',
711
- 'pan_(for_cooking)',
712
- 'pan_(metal_container)',
713
- 'pancake',
714
- 'papaya',
715
- 'paper_plate',
716
- 'paper_towel',
717
- 'paperback_book',
718
- 'paperweight',
719
- 'parachute',
720
- 'parakeet',
721
- 'parasail_(sports)',
722
- 'parasol',
723
- 'parchment',
724
- 'parka',
725
- 'parking_meter',
726
- 'parrot',
727
- 'passenger_car_(part_of_a_train)',
728
- 'passenger_ship',
729
- 'passport',
730
- 'pastry',
731
- 'patty_(food)',
732
- 'pea_(food)',
733
- 'peach',
734
- 'peanut_butter',
735
- 'pear',
736
- 'peeler_(tool_for_fruit_and_vegetables)',
737
- 'pegboard',
738
- 'pelican',
739
- 'pen',
740
- 'pencil',
741
- 'pencil_box',
742
- 'pencil_sharpener',
743
- 'pendulum',
744
- 'penguin',
745
- 'pennant',
746
- 'penny_(coin)',
747
- 'pepper',
748
- 'pepper_mill',
749
- 'perfume',
750
- 'persimmon',
751
- 'person',
752
- 'pet',
753
- 'pew_(church_bench)',
754
- 'phonebook',
755
- 'phonograph_record',
756
- 'piano',
757
- 'pickle',
758
- 'pickup_truck',
759
- 'pie',
760
- 'pigeon',
761
- 'piggy_bank',
762
- 'pillow',
763
- 'pineapple',
764
- 'pinecone',
765
- 'ping-pong_ball',
766
- 'pinwheel',
767
- 'pipe',
768
- 'pipe_bowl',
769
- 'pirate_flag',
770
- 'pistol',
771
- 'pita_(bread)',
772
- 'pitcher_(vessel_for_liquid)',
773
- 'pitchfork',
774
- 'pizza',
775
- 'place_mat',
776
- 'plastic_bag',
777
- 'plate',
778
- 'platter',
779
- 'playpen',
780
- 'pliers',
781
- 'plow_(farm_equipment)',
782
- 'plume',
783
- 'pocket_watch',
784
- 'pocketknife',
785
- 'poker_(fire_stirring_tool)',
786
- 'poker_chip',
787
- 'polar_bear',
788
- 'pole',
789
- 'police_cruiser',
790
- 'polo_shirt',
791
- 'poncho',
792
- 'pony',
793
- 'pool_table',
794
- 'pop_(soda)',
795
- 'popsicle',
796
- 'postbox_(public)',
797
- 'postcard',
798
- 'poster',
799
- 'pot',
800
- 'potato',
801
- 'potholder',
802
- 'pottery',
803
- 'pouch',
804
- 'power_shovel',
805
- 'prawn',
806
- 'pretzel',
807
- 'printer',
808
- 'projectile_(weapon)',
809
- 'projector',
810
- 'propeller',
811
- 'prune',
812
- 'pudding',
813
- 'puffer_(fish)',
814
- 'puffin',
815
- 'pug-dog',
816
- 'pumpkin',
817
- 'puncher',
818
- 'puppet',
819
- 'puppy',
820
- 'quesadilla',
821
- 'quiche',
822
- 'quilt',
823
- 'rabbit',
824
- 'race_car',
825
- 'racket',
826
- 'radar',
827
- 'radiator',
828
- 'radio_receiver',
829
- 'radish',
830
- 'raft',
831
- 'rag_doll',
832
- 'railcar_(part_of_a_train)',
833
- 'raincoat',
834
- 'ram_(animal)',
835
- 'raspberry',
836
- 'rat',
837
- 'reamer_(juicer)',
838
- 'rearview_mirror',
839
- 'receipt',
840
- 'recliner',
841
- 'record_player',
842
- 'reflector',
843
- 'refrigerator',
844
- 'remote_control',
845
- 'rhinoceros',
846
- 'rib_(food)',
847
- 'rifle',
848
- 'ring',
849
- 'river_boat',
850
- 'road_map',
851
- 'robe',
852
- 'rocking_chair',
853
- 'rodent',
854
- 'roller_skate',
855
- 'rolling_pin',
856
- 'root_beer',
857
- 'router_(computer_equipment)',
858
- 'rubber_band',
859
- 'runner_(carpet)',
860
- 'saddle_(on_an_animal)',
861
- 'saddle_blanket',
862
- 'saddlebag',
863
- 'safety_pin',
864
- 'sail',
865
- 'salad',
866
- 'salad_plate',
867
- 'salami',
868
- 'salmon_(fish)',
869
- 'salmon_(food)',
870
- 'salsa',
871
- 'saltshaker',
872
- 'sandal_(type_of_shoe)',
873
- 'sandwich',
874
- 'satchel',
875
- 'saucepan',
876
- 'saucer',
877
- 'sausage',
878
- 'sawhorse',
879
- 'saxophone',
880
- 'scale_(measuring_instrument)',
881
- 'scarecrow',
882
- 'scarf',
883
- 'school_bus',
884
- 'scissors',
885
- 'scoreboard',
886
- 'scraper',
887
- 'screwdriver',
888
- 'scrubbing_brush',
889
- 'sculpture',
890
- 'seabird',
891
- 'seahorse',
892
- 'seaplane',
893
- 'seashell',
894
- 'sewing_machine',
895
- 'shaker',
896
- 'shampoo',
897
- 'shark',
898
- 'sharpener',
899
- 'shaver_(electric)',
900
- 'shaving_cream',
901
- 'shawl',
902
- 'shears',
903
- 'sheep',
904
- 'shepherd_dog',
905
- 'sherbert',
906
- 'shield',
907
- 'shirt',
908
- 'shoe',
909
- 'shopping_bag',
910
- 'shopping_cart',
911
- 'short_pants',
912
- 'shot_glass',
913
- 'shoulder_bag',
914
- 'shovel',
915
- 'shower_cap',
916
- 'shower_curtain',
917
- 'shower_head',
918
- 'shredder_(for_paper)',
919
- 'signboard',
920
- 'silo',
921
- 'sink',
922
- 'skateboard',
923
- 'skewer',
924
- 'ski',
925
- 'ski_boot',
926
- 'ski_parka',
927
- 'ski_pole',
928
- 'skirt',
929
- 'skullcap',
930
- 'sled',
931
- 'sleeping_bag',
932
- 'slide',
933
- 'slipper_(footwear)',
934
- 'smoothie',
935
- 'snake',
936
- 'snowboard',
937
- 'snowman',
938
- 'snowmobile',
939
- 'soap',
940
- 'soccer_ball',
941
- 'sock',
942
- 'sofa',
943
- 'sofa_bed',
944
- 'softball',
945
- 'solar_array',
946
- 'sombrero',
947
- 'soup',
948
- 'soup_bowl',
949
- 'soupspoon',
950
- 'soya_milk',
951
- 'space_shuttle',
952
- 'sparkler_(fireworks)',
953
- 'spatula',
954
- 'speaker_(stero_equipment)',
955
- 'spear',
956
- 'spectacles',
957
- 'spice_rack',
958
- 'spider',
959
- 'sponge',
960
- 'spoon',
961
- 'sportswear',
962
- 'spotlight',
963
- 'squid_(food)',
964
- 'squirrel',
965
- 'stagecoach',
966
- 'stapler_(stapling_machine)',
967
- 'starfish',
968
- 'statue_(sculpture)',
969
- 'steak_(food)',
970
- 'steak_knife',
971
- 'steering_wheel',
972
- 'step_stool',
973
- 'stepladder',
974
- 'stereo_(sound_system)',
975
- 'stew',
976
- 'stirrer',
977
- 'stirrup',
978
- 'stool',
979
- 'stop_sign',
980
- 'stove',
981
- 'strainer',
982
- 'strap',
983
- 'straw_(for_drinking)',
984
- 'strawberry',
985
- 'street_sign',
986
- 'streetlight',
987
- 'string_cheese',
988
- 'stylus',
989
- 'subwoofer',
990
- 'sugar_bowl',
991
- 'sugarcane_(plant)',
992
- 'suit_(clothing)',
993
- 'suitcase',
994
- 'sunflower',
995
- 'sunglasses',
996
- 'sunhat',
997
- 'surfboard',
998
- 'sushi',
999
- 'suspenders',
1000
- 'sweat_pants',
1001
- 'sweatband',
1002
- 'sweater',
1003
- 'sweatshirt',
1004
- 'sweet_potato',
1005
- 'swimsuit',
1006
- 'sword',
1007
- 'syringe',
1008
- 'table',
1009
- 'table-tennis_table',
1010
- 'table_lamp',
1011
- 'tablecloth',
1012
- 'tachometer',
1013
- 'taco',
1014
- 'tag',
1015
- 'taillight',
1016
- 'tambourine',
1017
- 'tank_(storage_vessel)',
1018
- 'tank_top_(clothing)',
1019
- 'tape_(sticky_cloth_or_paper)',
1020
- 'tape_measure',
1021
- 'tapestry',
1022
- 'tarp',
1023
- 'tartan',
1024
- 'tassel',
1025
- 'teacup',
1026
- 'teakettle',
1027
- 'teapot',
1028
- 'teddy_bear',
1029
- 'telephone',
1030
- 'telephone_booth',
1031
- 'telephone_pole',
1032
- 'telephoto_lens',
1033
- 'television_camera',
1034
- 'television_set',
1035
- 'tennis_ball',
1036
- 'tennis_racket',
1037
- 'tequila',
1038
- 'thermometer',
1039
- 'thermos_bottle',
1040
- 'thermostat',
1041
- 'thimble',
1042
- 'thread',
1043
- 'thumbtack',
1044
- 'tiara',
1045
- 'tiger',
1046
- 'tights_(clothing)',
1047
- 'timer',
1048
- 'tinfoil',
1049
- 'tinsel',
1050
- 'tissue_paper',
1051
- 'toast_(food)',
1052
- 'toaster',
1053
- 'toaster_oven',
1054
- 'tobacco_pipe',
1055
- 'toilet',
1056
- 'toilet_tissue',
1057
- 'tomato',
1058
- 'tongs',
1059
- 'toolbox',
1060
- 'toothbrush',
1061
- 'toothpaste',
1062
- 'toothpick',
1063
- 'tortilla',
1064
- 'tote_bag',
1065
- 'tow_truck',
1066
- 'towel',
1067
- 'towel_rack',
1068
- 'toy',
1069
- 'tractor_(farm_equipment)',
1070
- 'traffic_light',
1071
- 'trailer_truck',
1072
- 'train_(railroad_vehicle)',
1073
- 'trampoline',
1074
- 'trash_can',
1075
- 'tray',
1076
- 'trench_coat',
1077
- 'triangle_(musical_instrument)',
1078
- 'tricycle',
1079
- 'tripod',
1080
- 'trophy_cup',
1081
- 'trousers',
1082
- 'truck',
1083
- 'truffle_(chocolate)',
1084
- 'trunk',
1085
- 'turban',
1086
- 'turkey_(food)',
1087
- 'turnip',
1088
- 'turtle',
1089
- 'turtleneck_(clothing)',
1090
- 'tux',
1091
- 'typewriter',
1092
- 'umbrella',
1093
- 'underdrawers',
1094
- 'underwear',
1095
- 'unicycle',
1096
- 'urinal',
1097
- 'urn',
1098
- 'vacuum_cleaner',
1099
- 'vase',
1100
- 'veil',
1101
- 'vending_machine',
1102
- 'vent',
1103
- 'vest',
1104
- 'videotape',
1105
- 'vinegar',
1106
- 'violin',
1107
- 'visor',
1108
- 'vodka',
1109
- 'volleyball',
1110
- 'vulture',
1111
- 'waffle',
1112
- 'waffle_iron',
1113
- 'wagon',
1114
- 'walking_cane',
1115
- 'walking_stick',
1116
- 'wall_clock',
1117
- 'wall_socket',
1118
- 'wallet',
1119
- 'walrus',
1120
- 'wardrobe',
1121
- 'washbasin',
1122
- 'watch',
1123
- 'water_bottle',
1124
- 'water_cooler',
1125
- 'water_faucet',
1126
- 'water_gun',
1127
- 'water_heater',
1128
- 'water_jug',
1129
- 'water_scooter',
1130
- 'water_ski',
1131
- 'water_tower',
1132
- 'watering_can',
1133
- 'watermelon',
1134
- 'weathervane',
1135
- 'webcam',
1136
- 'wedding_cake',
1137
- 'wedding_ring',
1138
- 'wet_suit',
1139
- 'wheel',
1140
- 'wheelchair',
1141
- 'whipped_cream',
1142
- 'wig',
1143
- 'wind_chime',
1144
- 'windmill',
1145
- 'window_box_(for_plants)',
1146
- 'windsock',
1147
- 'wine_bottle',
1148
- 'wine_bucket',
1149
- 'wineglass',
1150
- 'wok',
1151
- 'wolf',
1152
- 'wooden_leg',
1153
- 'wooden_spoon',
1154
- 'wreath',
1155
- 'wrench',
1156
- 'wristband',
1157
- 'wristlet',
1158
- 'yacht',
1159
- 'yogurt',
1160
- 'zebra',
1161
- 'zucchini'
1162
- ]
 
openshape/lvis_cats.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:71baf2d3f89884a082f1db75d0e94ac9a3b8036553877a3fdd98861cd01c4aec
3
- size 5919467
 
openshape/pointnet_util.py DELETED
@@ -1,323 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from time import time
5
- import numpy as np
6
- import dgl.geometry
7
-
8
- def timeit(tag, t):
9
- print("{}: {}s".format(tag, time() - t))
10
- return time()
11
-
12
- def pc_normalize(pc):
13
- l = pc.shape[0]
14
- centroid = np.mean(pc, axis=0)
15
- pc = pc - centroid
16
- m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
17
- pc = pc / m
18
- return pc
19
-
20
- def square_distance(src, dst):
21
- """
22
- Calculate Euclid distance between each two points.
23
-
24
- src^T * dst = xn * xm + yn * ym + zn * zm;
25
- sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
26
- sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
27
- dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
28
- = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
29
-
30
- Input:
31
- src: source points, [B, N, C]
32
- dst: target points, [B, M, C]
33
- Output:
34
- dist: per-point square distance, [B, N, M]
35
- """
36
- B, N, _ = src.shape
37
- _, M, _ = dst.shape
38
- dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
39
- dist += torch.sum(src ** 2, -1).view(B, N, 1)
40
- dist += torch.sum(dst ** 2, -1).view(B, 1, M)
41
- return dist
42
-
43
-
44
- def index_points(points, idx):
45
- """
46
-
47
- Input:
48
- points: input points data, [B, N, C]
49
- idx: sample index data, [B, S]
50
- Return:
51
- new_points:, indexed points data, [B, S, C]
52
- """
53
- device = points.device
54
- B = points.shape[0]
55
- view_shape = list(idx.shape)
56
- view_shape[1:] = [1] * (len(view_shape) - 1)
57
- repeat_shape = list(idx.shape)
58
- repeat_shape[0] = 1
59
- batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
60
- new_points = points[batch_indices, idx, :]
61
- return new_points
62
-
63
-
64
- def farthest_point_sample(xyz, npoint):
65
- """
66
- Input:
67
- xyz: pointcloud data, [B, N, 3]
68
- npoint: number of samples
69
- Return:
70
- centroids: sampled pointcloud index, [B, npoint]
71
- """
72
- return dgl.geometry.farthest_point_sampler(xyz, npoint)
73
- device = xyz.device
74
- B, N, C = xyz.shape
75
- centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
76
- distance = torch.ones(B, N).to(device) * 1e10
77
- farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
78
- batch_indices = torch.arange(B, dtype=torch.long).to(device)
79
- for i in range(npoint):
80
- centroids[:, i] = farthest
81
- centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
82
- dist = torch.sum((xyz - centroid) ** 2, -1)
83
- mask = dist < distance
84
- distance[mask] = dist[mask]
85
- farthest = torch.max(distance, -1)[1]
86
- return centroids
87
-
88
-
89
- def query_ball_point(radius, nsample, xyz, new_xyz):
90
- """
91
- Input:
92
- radius: local region radius
93
- nsample: max sample number in local region
94
- xyz: all points, [B, N, 3]
95
- new_xyz: query points, [B, S, 3]
96
- Return:
97
- group_idx: grouped points index, [B, S, nsample]
98
- """
99
- device = xyz.device
100
- B, N, C = xyz.shape
101
- _, S, _ = new_xyz.shape
102
- group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
103
- sqrdists = square_distance(new_xyz, xyz)
104
- group_idx[sqrdists > radius ** 2] = N
105
- group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
106
- group_first = group_idx[..., :1].repeat([1, 1, nsample])
107
- mask = group_idx == N
108
- group_idx[mask] = group_first[mask]
109
- return group_idx
110
-
111
-
112
- def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
113
- """
114
- Input:
115
- npoint:
116
- radius:
117
- nsample:
118
- xyz: input points position data, [B, N, 3]
119
- points: input points data, [B, N, D]
120
- Return:
121
- new_xyz: sampled points position data, [B, npoint, nsample, 3]
122
- new_points: sampled points data, [B, npoint, nsample, 3+D]
123
- """
124
- B, N, C = xyz.shape
125
- S = npoint
126
- fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C]
127
- # torch.cuda.empty_cache()
128
- new_xyz = index_points(xyz, fps_idx)
129
- # torch.cuda.empty_cache()
130
- idx = query_ball_point(radius, nsample, xyz, new_xyz)
131
- # torch.cuda.empty_cache()
132
- grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
133
- # torch.cuda.empty_cache()
134
- grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
135
- # torch.cuda.empty_cache()
136
-
137
- if points is not None:
138
- grouped_points = index_points(points, idx)
139
- new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
140
- else:
141
- new_points = grouped_xyz_norm
142
- if returnfps:
143
- return new_xyz, new_points, grouped_xyz, fps_idx
144
- else:
145
- return new_xyz, new_points
146
-
147
-
148
- def sample_and_group_all(xyz, points):
149
- """
150
- Input:
151
- xyz: input points position data, [B, N, 3]
152
- points: input points data, [B, N, D]
153
- Return:
154
- new_xyz: sampled points position data, [B, 1, 3]
155
- new_points: sampled points data, [B, 1, N, 3+D]
156
- """
157
- device = xyz.device
158
- B, N, C = xyz.shape
159
- new_xyz = torch.zeros(B, 1, C).to(device)
160
- grouped_xyz = xyz.view(B, 1, N, C)
161
- if points is not None:
162
- new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
163
- else:
164
- new_points = grouped_xyz
165
- return new_xyz, new_points
166
-
167
-
168
- class PointNetSetAbstraction(nn.Module):
169
- def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
170
- super(PointNetSetAbstraction, self).__init__()
171
- self.npoint = npoint
172
- self.radius = radius
173
- self.nsample = nsample
174
- self.mlp_convs = nn.ModuleList()
175
- self.mlp_bns = nn.ModuleList()
176
- last_channel = in_channel
177
- for out_channel in mlp:
178
- self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
179
- self.mlp_bns.append(nn.BatchNorm2d(out_channel))
180
- last_channel = out_channel
181
- self.group_all = group_all
182
-
183
- def forward(self, xyz, points):
184
- """
185
- Input:
186
- xyz: input points position data, [B, C, N]
187
- points: input points data, [B, D, N]
188
- Return:
189
- new_xyz: sampled points position data, [B, C, S]
190
- new_points_concat: sample points feature data, [B, D', S]
191
- """
192
- xyz = xyz.permute(0, 2, 1)
193
- if points is not None:
194
- points = points.permute(0, 2, 1)
195
-
196
- if self.group_all:
197
- new_xyz, new_points = sample_and_group_all(xyz, points)
198
- else:
199
- new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
200
- # new_xyz: sampled points position data, [B, npoint, C]
201
- # new_points: sampled points data, [B, npoint, nsample, C+D]
202
- new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
203
- for i, conv in enumerate(self.mlp_convs):
204
- bn = self.mlp_bns[i]
205
- new_points = F.relu(bn(conv(new_points)))
206
-
207
- new_points = torch.max(new_points, 2)[0]
208
- new_xyz = new_xyz.permute(0, 2, 1)
209
- return new_xyz, new_points
210
-
211
-
212
- class PointNetSetAbstractionMsg(nn.Module):
213
- def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
214
- super(PointNetSetAbstractionMsg, self).__init__()
215
- self.npoint = npoint
216
- self.radius_list = radius_list
217
- self.nsample_list = nsample_list
218
- self.conv_blocks = nn.ModuleList()
219
- self.bn_blocks = nn.ModuleList()
220
- for i in range(len(mlp_list)):
221
- convs = nn.ModuleList()
222
- bns = nn.ModuleList()
223
- last_channel = in_channel + 3
224
- for out_channel in mlp_list[i]:
225
- convs.append(nn.Conv2d(last_channel, out_channel, 1))
226
- bns.append(nn.BatchNorm2d(out_channel))
227
- last_channel = out_channel
228
- self.conv_blocks.append(convs)
229
- self.bn_blocks.append(bns)
230
-
231
- def forward(self, xyz, points):
232
- """
233
- Input:
234
- xyz: input points position data, [B, C, N]
235
- points: input points data, [B, D, N]
236
- Return:
237
- new_xyz: sampled points position data, [B, C, S]
238
- new_points_concat: sample points feature data, [B, D', S]
239
- """
240
- xyz = xyz.permute(0, 2, 1)
241
- if points is not None:
242
- points = points.permute(0, 2, 1)
243
-
244
- B, N, C = xyz.shape
245
- S = self.npoint
246
- new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
247
- new_points_list = []
248
- for i, radius in enumerate(self.radius_list):
249
- K = self.nsample_list[i]
250
- group_idx = query_ball_point(radius, K, xyz, new_xyz)
251
- grouped_xyz = index_points(xyz, group_idx)
252
- grouped_xyz -= new_xyz.view(B, S, 1, C)
253
- if points is not None:
254
- grouped_points = index_points(points, group_idx)
255
- grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
256
- else:
257
- grouped_points = grouped_xyz
258
-
259
- grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
260
- for j in range(len(self.conv_blocks[i])):
261
- conv = self.conv_blocks[i][j]
262
- bn = self.bn_blocks[i][j]
263
- grouped_points = F.relu(bn(conv(grouped_points)))
264
- new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
265
- new_points_list.append(new_points)
266
-
267
- new_xyz = new_xyz.permute(0, 2, 1)
268
- new_points_concat = torch.cat(new_points_list, dim=1)
269
- return new_xyz, new_points_concat
270
-
271
-
272
- class PointNetFeaturePropagation(nn.Module):
273
- def __init__(self, in_channel, mlp):
274
- super(PointNetFeaturePropagation, self).__init__()
275
- self.mlp_convs = nn.ModuleList()
276
- self.mlp_bns = nn.ModuleList()
277
- last_channel = in_channel
278
- for out_channel in mlp:
279
- self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
280
- self.mlp_bns.append(nn.BatchNorm1d(out_channel))
281
- last_channel = out_channel
282
-
283
- def forward(self, xyz1, xyz2, points1, points2):
284
- """
285
- Input:
286
- xyz1: input points position data, [B, C, N]
287
- xyz2: sampled input points position data, [B, C, S]
288
- points1: input points data, [B, D, N]
289
- points2: input points data, [B, D, S]
290
- Return:
291
- new_points: upsampled points data, [B, D', N]
292
- """
293
- xyz1 = xyz1.permute(0, 2, 1)
294
- xyz2 = xyz2.permute(0, 2, 1)
295
-
296
- points2 = points2.permute(0, 2, 1)
297
- B, N, C = xyz1.shape
298
- _, S, _ = xyz2.shape
299
-
300
- if S == 1:
301
- interpolated_points = points2.repeat(1, N, 1)
302
- else:
303
- dists = square_distance(xyz1, xyz2)
304
- dists, idx = dists.sort(dim=-1)
305
- dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
306
-
307
- dist_recip = 1.0 / (dists + 1e-8)
308
- norm = torch.sum(dist_recip, dim=2, keepdim=True)
309
- weight = dist_recip / norm
310
- interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
311
-
312
- if points1 is not None:
313
- points1 = points1.permute(0, 2, 1)
314
- new_points = torch.cat([points1, interpolated_points], dim=-1)
315
- else:
316
- new_points = interpolated_points
317
-
318
- new_points = new_points.permute(0, 2, 1)
319
- for i, conv in enumerate(self.mlp_convs):
320
- bn = self.mlp_bns[i]
321
- new_points = F.relu(bn(conv(new_points)))
322
- return new_points
323
-
 
openshape/ppat_rgb.py DELETED
@@ -1,118 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch_redstone as rst
4
- from einops import rearrange
5
- from .pointnet_util import PointNetSetAbstraction
6
-
7
-
8
- class PreNorm(nn.Module):
9
- def __init__(self, dim, fn):
10
- super().__init__()
11
- self.norm = nn.LayerNorm(dim)
12
- self.fn = fn
13
- def forward(self, x, *extra_args, **kwargs):
14
- return self.fn(self.norm(x), *extra_args, **kwargs)
15
-
16
- class FeedForward(nn.Module):
17
- def __init__(self, dim, hidden_dim, dropout = 0.):
18
- super().__init__()
19
- self.net = nn.Sequential(
20
- nn.Linear(dim, hidden_dim),
21
- nn.GELU(),
22
- nn.Dropout(dropout),
23
- nn.Linear(hidden_dim, dim),
24
- nn.Dropout(dropout)
25
- )
26
- def forward(self, x):
27
- return self.net(x)
28
-
29
- class Attention(nn.Module):
30
- def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., rel_pe = False):
31
- super().__init__()
32
- inner_dim = dim_head * heads
33
- project_out = not (heads == 1 and dim_head == dim)
34
-
35
- self.heads = heads
36
- self.scale = dim_head ** -0.5
37
-
38
- self.attend = nn.Softmax(dim = -1)
39
- self.dropout = nn.Dropout(dropout)
40
-
41
- self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
42
-
43
- self.to_out = nn.Sequential(
44
- nn.Linear(inner_dim, dim),
45
- nn.Dropout(dropout)
46
- ) if project_out else nn.Identity()
47
-
48
- self.rel_pe = rel_pe
49
- if rel_pe:
50
- self.pe = nn.Sequential(nn.Conv2d(3, 64, 1), nn.ReLU(), nn.Conv2d(64, 1, 1))
51
-
52
- def forward(self, x, centroid_delta):
53
- qkv = self.to_qkv(x).chunk(3, dim = -1)
54
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
55
-
56
- pe = self.pe(centroid_delta) if self.rel_pe else 0
57
- dots = (torch.matmul(q, k.transpose(-1, -2)) + pe) * self.scale
58
-
59
- attn = self.attend(dots)
60
- attn = self.dropout(attn)
61
-
62
- out = torch.matmul(attn, v)
63
- out = rearrange(out, 'b h n d -> b n (h d)')
64
- return self.to_out(out)
65
-
66
-
67
- class Transformer(nn.Module):
68
- def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., rel_pe = False):
69
- super().__init__()
70
- self.layers = nn.ModuleList([])
71
- for _ in range(depth):
72
- self.layers.append(nn.ModuleList([
73
- PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, rel_pe = rel_pe)),
74
- PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
75
- ]))
76
- def forward(self, x, centroid_delta):
77
- for attn, ff in self.layers:
78
- x = attn(x, centroid_delta) + x
79
- x = ff(x) + x
80
- return x
81
-
82
-
83
- class PointPatchTransformer(nn.Module):
84
- def __init__(self, dim, depth, heads, mlp_dim, sa_dim, patches, prad, nsamp, in_dim=3, dim_head=64, rel_pe=False, patch_dropout=0) -> None:
85
- super().__init__()
86
- self.patches = patches
87
- self.patch_dropout = patch_dropout
88
- self.sa = PointNetSetAbstraction(npoint=patches, radius=prad, nsample=nsamp, in_channel=in_dim + 3, mlp=[64, 64, sa_dim], group_all=False)
89
- self.lift = nn.Sequential(nn.Conv1d(sa_dim + 3, dim, 1), rst.Lambda(lambda x: torch.permute(x, [0, 2, 1])), nn.LayerNorm([dim]))
90
- self.cls_token = nn.Parameter(torch.randn(dim))
91
- self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, 0.0, rel_pe)
92
-
93
- def forward(self, features):
94
- self.sa.npoint = self.patches
95
- if self.training:
96
- self.sa.npoint -= self.patch_dropout
97
- # print("input", features.shape)
98
- centroids, feature = self.sa(features[:, :3], features)
99
- # print("f", feature.shape, 'c', centroids.shape)
100
- x = self.lift(torch.cat([centroids, feature], dim=1))
101
-
102
- x = rst.supercat([self.cls_token, x], dim=-2)
103
- centroids = rst.supercat([centroids.new_zeros(1), centroids], dim=-1)
104
-
105
- centroid_delta = centroids.unsqueeze(-1) - centroids.unsqueeze(-2)
106
- x = self.transformer(x, centroid_delta)
107
-
108
- return x[:, 0]
109
-
110
-
111
- class Projected(nn.Module):
112
- def __init__(self, ppat, proj) -> None:
113
- super().__init__()
114
- self.ppat = ppat
115
- self.proj = proj
116
-
117
- def forward(self, features: torch.Tensor):
118
- return self.proj(self.ppat(features))
 
openshape/sd_pc2img.py DELETED
@@ -1,39 +0,0 @@
1
- import numpy
2
- import torch
3
- import torch_redstone as rst
4
- import transformers
5
- from diffusers import StableUnCLIPImg2ImgPipeline
6
-
7
-
8
- class Wrapper(transformers.modeling_utils.PreTrainedModel):
9
- def __init__(self) -> None:
10
- super().__init__(transformers.configuration_utils.PretrainedConfig())
11
- self.param = torch.nn.Parameter(torch.tensor(0.))
12
-
13
- def forward(self, x):
14
- return rst.ObjectProxy(image_embeds=x)
15
-
16
-
17
- pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
18
- "diffusers/stable-diffusion-2-1-unclip-i2i-l",
19
- image_encoder = Wrapper()
20
- )
21
- if torch.cuda.is_available():
22
- pipe = pipe.to('cuda:' + str(torch.cuda.current_device()))
23
- pipe.enable_model_cpu_offload(torch.cuda.current_device())
24
-
25
-
26
- @torch.no_grad()
27
- def pc_to_image(pc_encoder: torch.nn.Module, pc, prompt, noise_level, width, height, cfg_scale, num_steps, callback):
28
- ref_dev = next(pc_encoder.parameters()).device
29
- enc = pc_encoder(torch.tensor(pc.T[None], device=ref_dev))
30
- return pipe(
31
- prompt="best quality, super high resolution, " + prompt,
32
- negative_prompt="cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
33
- image=torch.nn.functional.normalize(enc, dim=-1) * (768 ** 0.5) / 2,
34
- width=width, height=height,
35
- guidance_scale=cfg_scale,
36
- noise_level=noise_level,
37
- callback=callback,
38
- num_inference_steps=num_steps
39
- ).images[0]