salma-remyx commited on
Commit
ebd9056
·
1 Parent(s): 8636313

update to VQASynth pipeline

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ checkpoints/ filter=lfs diff=lfs merge=lfs -text
37
+ checkpoints/depth_pro.pt filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04
2
+
3
+ ARG DEBIAN_FRONTEND=noninteractive
4
+ ENV CUDA_HOME /usr/local/cuda-11.8/
5
+
6
+ WORKDIR /app
7
+
8
+ ENV PATH="/usr/local/cuda-11.8/bin:${PATH}"
9
+ ENV LD_LIBRARY_PATH="/usr/local/cuda-11.8/lib64:${LD_LIBRARY_PATH}"
10
+
11
+ RUN apt-get update && apt-get install -y software-properties-common wget && \
12
+ add-apt-repository ppa:deadsnakes/ppa && \
13
+ apt-get update && \
14
+ apt-get install -y build-essential git wget curl && \
15
+ apt-get install -y python3.10 python3.10-dev python3.10-distutils python3-venv && \
16
+ update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 && \
17
+ update-alternatives --set python3 /usr/bin/python3.10 && \
18
+ apt-get install -y zlib1g-dev libexpat1-dev
19
+
20
+ RUN wget https://github.com/Kitware/CMake/releases/download/v3.26.4/cmake-3.26.4-linux-x86_64.sh && \
21
+ chmod +x cmake-3.26.4-linux-x86_64.sh && \
22
+ ./cmake-3.26.4-linux-x86_64.sh --skip-license --prefix=/usr/local && \
23
+ rm cmake-3.26.4-linux-x86_64.sh
24
+
25
+ RUN wget https://bootstrap.pypa.io/get-pip.py && \
26
+ python3 get-pip.py && \
27
+ rm get-pip.py
28
+
29
+ RUN python3 -m pip install --upgrade pip && python3 -m pip install setuptools==65.0.1 wheel spacy==3.7.5
30
+ RUN python3 -m spacy download en_core_web_sm
31
+
32
+ RUN python3 -m pip install numpy==1.21.0
33
+ RUN python3 -m pip install scikit-learn==1.0.2 --prefer-binary
34
+
35
+ RUN apt-get install --no-install-recommends wget ffmpeg=7:* \
36
+ libsm6=2:* libxext6=2:* git=1:* vim=2:* -y \
37
+ && apt-get clean && apt-get autoremove && rm -rf /var/lib/apt/lists/*
38
+
39
+ RUN wget https://github.com/mikefarah/yq/releases/download/v4.30.8/yq_linux_amd64 -O /usr/bin/yq \
40
+ && chmod +x /usr/bin/yq
41
+
42
+ RUN pip install git+https://github.com/apple/ml-depth-pro.git
43
+ RUN pip install 'git+https://github.com/facebookresearch/sam2.git'
44
+ RUN pip install git+https://github.com/openai/CLIP.git
45
+
46
+ RUN pip install --upgrade torch==2.4.0+cu118 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu118
47
+
48
+ COPY . /app
49
+ RUN pip install -r requirements.txt
50
+ RUN pip uninstall -y flash_attn
51
+ RUN pip install git+https://github.com/Dao-AILab/flash-attention.git@v2.6.0
52
+
53
+ RUN pip uninstall -y onnxruntime onnxruntime-gpu
54
+ RUN pip install onnxruntime-gpu==1.18.1
55
+
56
+ # Expose the port Gradio will run on
57
+ EXPOSE 7860
58
+
59
+ # Run the Gradio app
60
+ CMD ["python3", "app.py"]
61
+
app.py CHANGED
@@ -1,203 +1,464 @@
1
- import gradio as gr
2
- import spaces
3
  import os
4
- import time
 
 
 
 
5
  from PIL import Image
6
- import functools
7
- from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava_stream, MLlavaForConditionalGeneration, chat_mllava
8
- from models.conversation import conv_templates
9
- from typing import List
10
- processor = MLlavaProcessor.from_pretrained("remyxai/SpaceMantis")
11
- model = LlavaForConditionalGeneration.from_pretrained("remyxai/SpaceMantis")
12
- conv_template = conv_templates['llama_3']
13
-
14
- @spaces.GPU
15
- def generate_stream(text:str, images:List[Image.Image], history: List[dict], **kwargs):
16
- global processor, model
17
- model = model.to("cuda")
18
- if not images:
19
- images = None
20
- for text, history in chat_mllava_stream(text, images, model, processor, history=history, **kwargs):
21
- yield text
22
-
23
- return text
24
-
25
- @spaces.GPU
26
- def generate(text:str, images:List[Image.Image], history: List[dict], **kwargs):
27
- global processor, model
28
- model = model.to("cuda")
29
- if not images:
30
- images = None
31
- generated_text, history = chat_mllava(text, images, model, processor, history=history, **kwargs)
32
- return generated_text
33
-
34
- def enable_next_image(uploaded_images, image):
35
- uploaded_images.append(image)
36
- return uploaded_images, gr.MultimodalTextbox(value=None, interactive=False)
37
-
38
- def add_message(history, message):
39
- if message["files"]:
40
- for file in message["files"]:
41
- history.append([(file,), None])
42
- if message["text"]:
43
- history.append([message["text"], None])
44
- return history, gr.MultimodalTextbox(value=None)
45
-
46
- def print_like_dislike(x: gr.LikeData):
47
- print(x.index, x.value, x.liked)
48
-
49
-
50
- def get_chat_history(history):
51
- chat_history = []
52
- user_role = conv_template.roles[0]
53
- assistant_role = conv_template.roles[1]
54
- for i, message in enumerate(history):
55
- if isinstance(message[0], str):
56
- chat_history.append({"role": user_role, "text": message[0]})
57
- if i != len(history) - 1:
58
- assert message[1], "The bot message is not provided, internal error"
59
- chat_history.append({"role": assistant_role, "text": message[1]})
60
- else:
61
- assert not message[1], "the bot message internal error, get: {}".format(message[1])
62
- chat_history.append({"role": assistant_role, "text": ""})
63
- return chat_history
64
-
65
-
66
- def get_chat_images(history):
67
- images = []
68
- for message in history:
69
- if isinstance(message[0], tuple):
70
- images.extend(message[0])
71
- return images
72
-
73
-
74
- def bot(history):
75
- print(history)
76
- cur_messages = {"text": "", "images": []}
77
- for message in history[::-1]:
78
- if message[1]:
79
- break
80
- if isinstance(message[0], str):
81
- cur_messages["text"] = message[0] + " " + cur_messages["text"]
82
- elif isinstance(message[0], tuple):
83
- cur_messages["images"].extend(message[0])
84
- cur_messages["text"] = cur_messages["text"].strip()
85
- cur_messages["images"] = cur_messages["images"][::-1]
86
- if not cur_messages["text"]:
87
- raise gr.Error("Please enter a message")
88
- if cur_messages['text'].count("<image>") < len(cur_messages['images']):
89
- gr.Warning("The number of images uploaded is more than the number of <image> placeholders in the text. Will automatically prepend <image> to the text.")
90
- cur_messages['text'] = "<image> "* (len(cur_messages['images']) - cur_messages['text'].count("<image>")) + cur_messages['text']
91
- history[-1][0] = cur_messages["text"]
92
- if cur_messages['text'].count("<image>") > len(cur_messages['images']):
93
- gr.Warning("The number of images uploaded is less than the number of <image> placeholders in the text. Will automatically remove extra <image> placeholders from the text.")
94
- cur_messages['text'] = cur_messages['text'][::-1].replace("<image>"[::-1], "", cur_messages['text'].count("<image>") - len(cur_messages['images']))[::-1]
95
- history[-1][0] = cur_messages["text"]
96
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
 
 
98
 
99
- chat_history = get_chat_history(history)
100
- chat_images = get_chat_images(history)
 
101
 
102
- generation_kwargs = {
103
- "max_new_tokens": 4096,
104
- "num_beams": 1,
105
- "do_sample": False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
- response = generate_stream(None, chat_images, chat_history, **generation_kwargs)
109
- for _output in response:
110
- history[-1][1] = _output
111
- time.sleep(0.05)
112
- yield history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
 
115
-
116
  def build_demo():
117
  with gr.Blocks() as demo:
118
-
119
- gr.Markdown(""" # SpaceMantis
120
- Mantis is a multimodal conversational AI model fine-tuned from [Mantis-8B-siglip-llama3](https://huggingface.co/remyxai/SpaceMantis/blob/main/TIGER-Lab/Mantis-8B-siglip-llama3) for enhanced spatial reasoning. It's optimized for multi-image reasoning, where inverleaved text and images can be used to generate responses.
 
121
 
122
- ### [Github](https://github.com/remyxai/VQASynth) | [Model](https://huggingface.co/remyxai/SpaceMantis) | [Dataset](https://huggingface.co/datasets/remyxai/mantis-spacellava)
123
  """)
124
-
125
- gr.Markdown("""## Chat with SpaceMantis
126
- SpaceMantis supports interleaved text-image input format, where you can simply use the placeholder `<image>` to indicate the position of uploaded images.
127
- The model is optimized for multi-image reasoning, while preserving the ability to chat about text and images in a single conversation.
128
- (The model currently serving is [🤗 remyxai/SpaceMantis](https://huggingface.co/remyxai/SpaceMantis))
129
  """)
130
-
131
- chatbot = gr.Chatbot(line_breaks=True)
132
- chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload images. Please use <image> to indicate the position of uploaded images", show_label=True)
133
-
134
- chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
135
-
136
- """
137
- with gr.Accordion(label='Advanced options', open=False):
138
- temperature = gr.Slider(
139
- label='Temperature',
140
- minimum=0.1,
141
- maximum=2.0,
142
- step=0.1,
143
- value=0.2,
144
- interactive=True
145
- )
146
- top_p = gr.Slider(
147
- label='Top-p',
148
- minimum=0.05,
149
- maximum=1.0,
150
- step=0.05,
151
- value=1.0,
152
- interactive=True
153
- )
154
- """
155
-
156
- bot_msg = chat_msg.success(bot, chatbot, chatbot, api_name="bot_response")
157
-
158
- chatbot.like(print_like_dislike, None, None)
159
 
160
  with gr.Row():
161
- send_button = gr.Button("Send")
162
- clear_button = gr.ClearButton([chatbot, chat_input])
 
 
 
 
 
 
 
 
 
163
 
164
- send_button.click(
165
- add_message, [chatbot, chat_input], [chatbot, chat_input]
166
- ).then(
167
- bot, chatbot, chatbot, api_name="bot_response"
 
168
  )
169
-
 
170
  gr.Examples(
171
  examples=[
172
- {
173
- "text": "Give me the height of the man in the red hat in feet.",
174
- "files": ["./examples/warehouse_rgb.jpg"]
175
- },
176
  ],
177
- inputs=[chat_input],
178
- )
179
-
 
 
 
180
  gr.Markdown("""
181
- ## Citation
182
- ```
183
- @article{chen2024spatialvlm,
184
- title = {SpatialVLM: Endowing Vision-Language Models with Spatial Reasoning Capabilities},
185
- author = {Chen, Boyuan and Xu, Zhuo and Kirmani, Sean and Ichter, Brian and Driess, Danny and Florence, Pete and Sadigh, Dorsa and Guibas, Leonidas and Xia, Fei},
186
- journal = {arXiv preprint arXiv:2401.12168},
187
- year = {2024},
188
- url = {https://arxiv.org/abs/2401.12168},
189
- }
190
-
191
- @article{jiang2024mantis,
192
- title={MANTIS: Interleaved Multi-Image Instruction Tuning},
193
- author={Jiang, Dongfu and He, Xuan and Zeng, Huaye and Wei, Con and Ku, Max and Liu, Qian and Chen, Wenhu},
194
- journal={arXiv preprint arXiv:2405.01483},
195
- year={2024}
196
- }
197
- ```""")
198
- return demo
199
-
200
 
201
- if __name__ == "__main__":
202
  demo = build_demo()
203
- demo.launch()
 
 
 
1
  import os
2
+ import sys
3
+ import uuid
4
+ import torch
5
+ import random
6
+ import numpy as np
7
  from PIL import Image
8
+ import open3d as o3d
9
+ import matplotlib.pyplot as plt
10
+
11
+ from transformers import AutoProcessor, AutoModelForCausalLM
12
+ from transformers import SamModel, SamProcessor
13
+
14
+ import depth_pro
15
+
16
+ import spacy
17
+ import gradio as gr
18
+
19
+ nlp = spacy.load("en_core_web_sm")
20
+
21
+ def find_subject(doc):
22
+ for token in doc:
23
+ # Check if the token is a subject
24
+ if "subj" in token.dep_:
25
+ return token.text, token.head
26
+ return None, None
27
+
28
+ def extract_descriptions(doc, head):
29
+ descriptions = []
30
+ for chunk in doc.noun_chunks:
31
+ # Check if the chunk is directly related to the subject's verb or is an attribute
32
+ if chunk.root.head == head or chunk.root.dep_ == 'attr':
33
+ descriptions.append(chunk.text)
34
+ return descriptions
35
+
36
+ def caption_refiner(caption):
37
+ doc = nlp(caption)
38
+ subject, action_verb = find_subject(doc)
39
+ if action_verb:
40
+ descriptions = extract_descriptions(doc, action_verb)
41
+ return ', '.join(descriptions)
42
+ else:
43
+ return caption
44
+
45
+ def sam2(image, input_boxes, model_id="facebook/sam-vit-base"):
46
+ device = "cuda" if torch.cuda.is_available() else "cpu"
47
+ model = SamModel.from_pretrained(model_id).to(device)
48
+ processor = SamProcessor.from_pretrained(model_id)
49
+ inputs = processor(image, input_boxes=[[input_boxes]], return_tensors="pt").to(device)
50
+ with torch.no_grad():
51
+ outputs = model(**inputs)
52
+
53
+ masks = processor.image_processor.post_process_masks(
54
+ outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
55
+ )
56
+ return masks
57
+
58
+ def load_florence2(model_id="microsoft/Florence-2-base-ft", device='cuda'):
59
+ torch_dtype = torch.float16 if device == 'cuda' else torch.float32
60
+ florence_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, trust_remote_code=True).to(device)
61
+ florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
62
+ return florence_model, florence_processor
63
+
64
+ def florence2(image, prompt="", task="<OD>"):
65
+ device = florence_model.device
66
+ torch_dtype = florence_model.dtype
67
+ inputs = florence_processor(text=task + prompt, images=image, return_tensors="pt").to(device, torch_dtype)
68
+ generated_ids = florence_model.generate(
69
+ input_ids=inputs["input_ids"],
70
+ pixel_values=inputs["pixel_values"],
71
+ max_new_tokens=1024,
72
+ num_beams=3,
73
+ do_sample=False
74
+ )
75
+ generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
76
+ parsed_answer = florence_processor.post_process_generation(generated_text, task=task, image_size=(image.width, image.height))
77
+ return parsed_answer[task]
78
+
79
+
80
+ # Load and preprocess an image.
81
+ def depth_estimation(image_path):
82
+ model.eval()
83
+ image, _, f_px = depth_pro.load_rgb(image_path)
84
+ image = transform(image)
85
+
86
+ # Run inference.
87
+ prediction = model.infer(image, f_px=f_px)
88
+ depth = prediction["depth"] # Depth in [m].
89
+ focallength_px = prediction["focallength_px"] # Focal length in pixels.
90
+ depth = depth.cpu().numpy()
91
+ return depth, focallength_px
92
+
93
+
94
+ def create_point_cloud_from_rgbd(rgb, depth, intrinsic_parameters):
95
+ rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
96
+ o3d.geometry.Image(rgb),
97
+ o3d.geometry.Image(depth),
98
+ depth_scale=10.0,
99
+ depth_trunc=100.0,
100
+ convert_rgb_to_intensity=False
101
+ )
102
+ intrinsic = o3d.camera.PinholeCameraIntrinsic()
103
+ intrinsic.set_intrinsics(intrinsic_parameters['width'], intrinsic_parameters['height'],
104
+ intrinsic_parameters['fx'], intrinsic_parameters['fy'],
105
+ intrinsic_parameters['cx'], intrinsic_parameters['cy'])
106
+ pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, intrinsic)
107
+ return pcd
108
+
109
+
110
+ def canonicalize_point_cloud(pcd, canonicalize_threshold=0.3):
111
+ # Segment the largest plane, assumed to be the floor
112
+ plane_model, inliers = pcd.segment_plane(distance_threshold=0.01, ransac_n=3, num_iterations=1000)
113
+
114
+ canonicalized = False
115
+ if len(inliers) / len(pcd.points) > canonicalize_threshold:
116
+ canonicalized = True
117
+
118
+ # Ensure the plane normal points upwards
119
+ if np.dot(plane_model[:3], [0, 1, 0]) < 0:
120
+ plane_model = -plane_model
121
+
122
+ # Normalize the plane normal vector
123
+ normal = plane_model[:3] / np.linalg.norm(plane_model[:3])
124
+
125
+ # Compute the new basis vectors
126
+ new_y = normal
127
+ new_x = np.cross(new_y, [0, 0, -1])
128
+ new_x /= np.linalg.norm(new_x)
129
+ new_z = np.cross(new_x, new_y)
130
+
131
+ # Create the transformation matrix
132
+ transformation = np.identity(4)
133
+ transformation[:3, :3] = np.vstack((new_x, new_y, new_z)).T
134
+ transformation[:3, 3] = -np.dot(transformation[:3, :3], pcd.points[inliers[0]])
135
+
136
+
137
+ # Apply the transformation
138
+ pcd.transform(transformation)
139
+
140
+ # Additional 180-degree rotation around the Z-axis
141
+ rotation_z_180 = np.array([[np.cos(np.pi), -np.sin(np.pi), 0],
142
+ [np.sin(np.pi), np.cos(np.pi), 0],
143
+ [0, 0, 1]])
144
+ pcd.rotate(rotation_z_180, center=(0, 0, 0))
145
+
146
+ return pcd, canonicalized, transformation
147
+ else:
148
+ return pcd, canonicalized, None
149
+
150
+
151
+ def compute_iou(box1, box2):
152
+ # Extract the coordinates
153
+ x1_min, y1_min, x1_max, y1_max = box1
154
+ x2_min, y2_min, x2_max, y2_max = box2
155
+
156
+ # Compute the intersection rectangle
157
+ x_inter_min = max(x1_min, x2_min)
158
+ y_inter_min = max(y1_min, y2_min)
159
+ x_inter_max = min(x1_max, x2_max)
160
+ y_inter_max = min(y1_max, y2_max)
161
+
162
+ # Intersection width and height
163
+ inter_width = max(0, x_inter_max - x_inter_min)
164
+ inter_height = max(0, y_inter_max - y_inter_min)
165
 
166
+ # Intersection area
167
+ inter_area = inter_width * inter_height
168
 
169
+ # Boxes areas
170
+ box1_area = (x1_max - x1_min) * (y1_max - y1_min)
171
+ box2_area = (x2_max - x2_min) * (y2_max - y2_min)
172
 
173
+ # Union area
174
+ union_area = box1_area + box2_area - inter_area
175
+
176
+ # Intersection over Union
177
+ iou = inter_area / union_area if union_area != 0 else 0
178
+
179
+ return iou
180
+
181
+
182
+ def human_like_distance(distance_meters, scale_factor=10):
183
+ # Define the choices with units included, focusing on the 0.1 to 10 meters range
184
+ distance_meters *= scale_factor
185
+ if distance_meters < 1: # For distances less than 1 meter
186
+ choices = [
187
+ (
188
+ round(distance_meters * 100, 2),
189
+ "centimeters",
190
+ 0.2,
191
+ ), # Centimeters for very small distances
192
+ (
193
+ round(distance_meters, 2),
194
+ "inches",
195
+ 0.8,
196
+ ), # Inches for the majority of cases under 1 meter
197
+ ]
198
+ elif distance_meters < 3: # For distances less than 3 meters
199
+ choices = [
200
+ (round(distance_meters, 2), "meters", 0.5),
201
+ (
202
+ round(distance_meters, 2),
203
+ "feet",
204
+ 0.5,
205
+ ), # Feet as a common unit within indoor spaces
206
+ ]
207
+ else: # For distances from 3 up to 10 meters
208
+ choices = [
209
+ (
210
+ round(distance_meters, 2),
211
+ "meters",
212
+ 0.7,
213
+ ), # Meters for clarity and international understanding
214
+ (
215
+ round(distance_meters, 2),
216
+ "feet",
217
+ 0.3,
218
+ ), # Feet for additional context
219
+ ]
220
+ # Normalize probabilities and make a selection
221
+ total_probability = sum(prob for _, _, prob in choices)
222
+ cumulative_distribution = []
223
+ cumulative_sum = 0
224
+ for value, unit, probability in choices:
225
+ cumulative_sum += probability / total_probability # Normalize probabilities
226
+ cumulative_distribution.append((cumulative_sum, value, unit))
227
+
228
+ # Randomly choose based on the cumulative distribution
229
+ r = random.random()
230
+ for cumulative_prob, value, unit in cumulative_distribution:
231
+ if r < cumulative_prob:
232
+ return f"{value} {unit}"
233
+
234
+ # Fallback to the last choice if something goes wrong
235
+ return f"{choices[-1][0]} {choices[-1][1]}"
236
+
237
+
238
+ def filter_bboxes(data, iou_threshold=0.5):
239
+ filtered_bboxes = []
240
+ filtered_labels = []
241
+
242
+ for i in range(len(data['bboxes'])):
243
+ current_box = data['bboxes'][i]
244
+ current_label = data['labels'][i]
245
+ is_duplicate = False
246
+
247
+ for j in range(len(filtered_bboxes)):
248
+ if current_label == filtered_labels[j]:# and compute_iou(current_box, filtered_bboxes[j]) > iou_threshold:
249
+ is_duplicate = True
250
+ break
251
+
252
+ if not is_duplicate:
253
+ filtered_bboxes.append(current_box)
254
+ filtered_labels.append(current_label)
255
+
256
+ return {'bboxes': filtered_bboxes, 'labels': filtered_labels, 'caption': data['caption']}
257
+
258
+ def process_image(image_path: str):
259
+ depth, fx = depth_estimation(image_path)
260
+
261
+ img = Image.open(image_path).convert('RGB')
262
+ width, height = img.size
263
+
264
+ description = florence2(img, task="<MORE_DETAILED_CAPTION>")
265
+ print(description)
266
+
267
+ regions = []
268
+ for cap in description.split('.'):
269
+ if cap:
270
+ roi = florence2(img, prompt=" " + cap, task="<CAPTION_TO_PHRASE_GROUNDING>")
271
+ roi["caption"] = caption_refiner(cap.lower())
272
+ roi = filter_bboxes(roi)
273
+ if len(roi['bboxes']) > 1:
274
+ flip = random.choice(['heads', 'tails'])
275
+ if flip == 'heads':
276
+ idx = random.randint(1, len(roi['bboxes']) - 1)
277
+ else:
278
+ idx = 0
279
+ if idx > 0: # test bbox IOU
280
+ roi['caption'] = roi['labels'][idx].lower() + ' with ' + roi['labels'][0].lower()
281
+ roi['bboxes'] = [roi['bboxes'][idx]]
282
+ roi['labels'] = [roi['labels'][idx]]
283
+
284
+ if roi['bboxes']:
285
+ regions.append(roi)
286
+ print(roi)
287
+
288
+ bboxes = [item['bboxes'][0] for item in regions]
289
+ n = len(bboxes)
290
+ distance_matrix = np.zeros((n, n))
291
+ for i in range(n):
292
+ for j in range(n):
293
+ if i != j:
294
+ distance_matrix[i][j] = 1 - compute_iou(bboxes[i], bboxes[j])
295
+
296
+ scores = np.sum(distance_matrix, axis=1)
297
+ selected_indices = np.argsort(scores)[-3:]
298
+ regions = [(regions[i]['bboxes'][0], regions[i]['caption']) for i in selected_indices][:2]
299
+
300
+ # Create point cloud
301
+ camera_intrinsics = intrinsic_parameters = {
302
+ 'width': width,
303
+ 'height': height,
304
+ 'fx': fx,
305
+ 'fy': fx * height / width,
306
+ 'cx': width / 2,
307
+ 'cy': height / 2,
308
  }
309
+
310
+ pcd = create_point_cloud_from_rgbd(np.array(img).copy(), depth, camera_intrinsics)
311
+ normed_pcd, canonicalized, transformation = canonicalize_point_cloud(pcd)
312
+
313
+
314
+ masks = []
315
+ for box, cap in regions:
316
+ masks.append((cap, sam2(img, box)))
317
+
318
+
319
+ point_clouds = []
320
+ for cap, mask in masks:
321
+ m = mask[0].numpy()[0].squeeze().transpose((1, 2, 0))
322
+ mask = np.any(m, axis=2)
323
+
324
+ try:
325
+ points = np.asarray(normed_pcd.points)
326
+ colors = np.asarray(normed_pcd.colors)
327
+ masked_points = points[mask.ravel()]
328
+ masked_colors = colors[mask.ravel()]
329
+
330
+ masked_point_cloud = o3d.geometry.PointCloud()
331
+ masked_point_cloud.points = o3d.utility.Vector3dVector(masked_points)
332
+ masked_point_cloud.colors = o3d.utility.Vector3dVector(masked_colors)
333
+
334
+ point_clouds.append((cap, masked_point_cloud))
335
+ except:
336
+ pass
337
+
338
+ boxes3D = []
339
+ centers = []
340
+ pcd = o3d.geometry.PointCloud()
341
+ for cap, pc in point_clouds[:2]:
342
+ cl, ind = pc.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
343
+ inlier_cloud = pc.select_by_index(ind)
344
+ pcd += inlier_cloud
345
+ obb = inlier_cloud.get_axis_aligned_bounding_box()
346
+ obb.color = (1, 0, 0)
347
+ centers.append(obb.get_center())
348
+ boxes3D.append(obb)
349
+
350
+
351
+ lines = [[0, 1]]
352
+ points = [centers[0], centers[1]]
353
+ distance = human_like_distance(np.asarray(point_clouds[0][1].compute_point_cloud_distance(point_clouds[-1][1])).mean())
354
+ text_output = "Distance between {} and {} is: {}".format(point_clouds[0][0], point_clouds[-1][0], distance)
355
+ print(text_output)
356
 
357
+ colors = [[1, 0, 0] for i in range(len(lines))] # Red color for lines
358
+ line_set = o3d.geometry.LineSet(
359
+ points=o3d.utility.Vector3dVector(points),
360
+ lines=o3d.utility.Vector2iVector(lines)
361
+ )
362
+ line_set.colors = o3d.utility.Vector3dVector(colors)
363
+
364
+ boxes3D.append(line_set)
365
+
366
+
367
+ uuid_out = str(uuid.uuid4())
368
+ ply_file = f"output_{uuid_out}.ply"
369
+ obj_file = f"output_{uuid_out}.obj"
370
+ o3d.io.write_point_cloud(ply_file, pcd)
371
+
372
+ mesh = o3d.io.read_triangle_mesh(ply_file)
373
+
374
+ o3d.io.write_triangle_mesh(obj_file, mesh)
375
+
376
+ return obj_file, text_output
377
+
378
+
379
+
380
+ def custom_draw_geometry_with_rotation(pcd):
381
+
382
+ def rotate_view(vis):
383
+ ctr = vis.get_view_control()
384
+ vis.get_render_option().background_color = [0, 0, 0]
385
+ ctr.rotate(1.0, 0.0)
386
+ # https://github.com/isl-org/Open3D/issues/1483
387
+ #parameters = o3d.io.read_pinhole_camera_parameters("ScreenCamera_2024-10-24-10-03-57.json")
388
+ #ctr.convert_from_pinhole_camera_parameters(parameters)
389
+ return False
390
+
391
+ o3d.visualization.draw_geometries_with_animation_callback([pcd] + boxes3D,
392
+ rotate_view)
393
 
394
 
 
395
  def build_demo():
396
  with gr.Blocks() as demo:
397
+ # Title and introductory Markdown
398
+ gr.Markdown("""
399
+ # Synthesizing SpatialVQA Samples with VQASynth
400
+ This space helps test the full [VQASynth](https://github.com/remyxai/VQASynth) scene reconstruction pipeline on a single image with visualizations.
401
 
402
+ ### [Github](https://github.com/remyxai/VQASynth) | [Collection](https://huggingface.co/collections/remyxai/spacevlms-66a3dbb924756d98e7aec678)
403
  """)
404
+
405
+ # Description for users
406
+ gr.Markdown("""
407
+ ## Instructions
408
+ Upload an image, and the tool will generate a corresponding 3D point cloud visualization of the objects found and an example prompt and response describing a spatial relationship between the objects.
409
  """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
 
411
  with gr.Row():
412
+ # Left Column: Inputs
413
+ with gr.Column():
414
+ # Image upload and processing button in the left column
415
+ image_input = gr.Image(type="filepath", label="Upload an Image")
416
+ generate_button = gr.Button("Generate")
417
+
418
+ # Right Column: Outputs
419
+ with gr.Column():
420
+ # 3D Model and Caption Outputs
421
+ model_output = gr.Model3D(label="3D Point Cloud") # Only used as output
422
+ caption_output = gr.Text(label="Caption")
423
 
424
+ # Link the button to process the image and display the outputs
425
+ generate_button.click(
426
+ process_image, # Your processing function
427
+ inputs=image_input,
428
+ outputs=[model_output, caption_output]
429
  )
430
+
431
+ # Examples section at the bottom
432
  gr.Examples(
433
  examples=[
434
+ ["./examples/warehouse_rgb.jpg"], ["./examples/spooky_doggy.png"], ["./examples/bee_and_flower.jpg"], ["./examples/road-through-dense-forest.jpg"], ["./examples/gears.png"] # Update with the path to your example image
 
 
 
435
  ],
436
+ inputs=image_input,
437
+ label="Example Images",
438
+ examples_per_page=5
439
+ )
440
+
441
+ # Citations
442
  gr.Markdown("""
443
+ ## Citation
444
+ ```
445
+ @article{chen2024spatialvlm,
446
+ title = {SpatialVLM: Endowing Vision-Language Models with Spatial Reasoning Capabilities},
447
+ author = {Chen, Boyuan and Xu, Zhuo and Kirmani, Sean and Ichter, Brian and Driess, Danny and Florence, Pete and Sadigh, Dorsa and Guibas, Leonidas and Xia, Fei},
448
+ journal = {arXiv preprint arXiv:2401.12168},
449
+ year = {2024},
450
+ url = {https://arxiv.org/abs/2401.12168},
451
+ }
452
+ ```
453
+ """)
454
+
455
+ return demo
456
+
457
+ if __name__ == "__main__":
458
+ global model, transform, florence_model, florence_processor
459
+ model, transform = depth_pro.create_model_and_transforms(device='cuda')
460
+ florence_model, florence_processor = load_florence2(device='cuda')
461
+
462
 
 
463
  demo = build_demo()
464
+ demo.launch(share=True)
checkpoints/depth_pro.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3eb35ca68168ad3d14cb150f8947a4edf85589941661fdb2686259c80685c0ce
3
+ size 1904446787
examples/bee_and_flower.jpg ADDED
examples/gears.png ADDED
examples/road-through-dense-forest.jpg ADDED
examples/spooky_doggy.png ADDED
requirements.txt CHANGED
@@ -2,5 +2,23 @@ torch
2
  transformers>=4.41.0
3
  Pillow
4
  gradio
5
- spaces
6
- multiprocess
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  transformers>=4.41.0
3
  Pillow
4
  gradio
5
+ accelerate==0.34.2
6
+ numpy==1.26.4
7
+ timm==1.0.9
8
+ einops==0.7.0
9
+ open3d==0.18.0
10
+ opencv-python==4.7.0.72
11
+ tqdm==4.64.1
12
+ torchprofile==0.0.4
13
+ matplotlib==3.6.2
14
+ huggingface-hub==0.24.7
15
+ onnx==1.13.1
16
+ onnxruntime==1.14.1
17
+ onnxsim==0.4.35
18
+ scipy==1.12.0
19
+ litellm==1.25.2
20
+ pycocotools==2.0.6
21
+ onnxruntime-gpu==1.18.1
22
+ pandas
23
+ html5lib
24
+ datasets