SegVol / utils.py
BoyaWu10's picture
download support (#3)
a2f2ef1
raw
history blame
No virus
5.08 kB
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageEnhance, ImageDraw
import torch
import streamlit as st
from model.inference_cpu import inference_case
# Initial drawing state: one semi-transparent orange rectangle with a blue
# outline, top-left corner at (50, 50), 100 x 100 before scaling.
# NOTE(review): the key set looks like Fabric.js object serialization (as used
# for a drawable canvas's initial drawing) — confirm against the caller.
initial_rectangle = {
    "version": "4.4.0",
    "objects": [
        {
            # Identity and geometry (origin is the top-left corner).
            "type": "rect",
            "version": "4.4.0",
            "originX": "left",
            "originY": "top",
            "left": 50,
            "top": 50,
            "width": 100,
            "height": 100,
            # Fill and outline.
            "fill": "rgba(255, 165, 0, 0.3)",
            "stroke": "#2909F1",
            "strokeWidth": 3,
            "strokeDashArray": None,
            "strokeLineCap": "butt",
            "strokeDashOffset": 0,
            "strokeLineJoin": "miter",
            "strokeUniform": True,
            "strokeMiterLimit": 4,
            # Transform: unscaled, unrotated, unflipped.
            "scaleX": 1,
            "scaleY": 1,
            "angle": 0,
            "flipX": False,
            "flipY": False,
            # Remaining rendering attributes.
            "opacity": 1,
            "shadow": None,
            "visible": True,
            "backgroundColor": "",
            "fillRule": "nonzero",
            "paintFirst": "fill",
            "globalCompositeOperation": "source-over",
            "skewX": 0,
            "skewY": 0,
            "rx": 0,
            "ry": 0,
        },
    ],
}
def run():
    """Run inference on the loaded case with whichever prompts are enabled.

    Reads the case tensors and prompt settings from ``st.session_state``,
    converts point/box prompts with the ``reflect_*_into_model`` helpers, and
    stores the two values returned by ``inference_case`` in
    ``st.session_state.preds_3D`` and ``st.session_state.preds_3D_ori``.
    """
    state = st.session_state
    image = state.data_item["image"].float()
    image_zoom_out = state.data_item["zoom_out_image"].float()

    # Every prompt stays None unless its toggle is on.
    text_prompt = state.text_prompt if state.use_text_prompt else None

    point_prompt = None
    if state.use_point_prompt and len(state.points) > 0:
        point_prompt = reflect_points_into_model(state.points)

    box_prompt = None
    if state.use_box_prompt:
        box_prompt = reflect_box_into_model(state.rectangle_3Dbox)

    # NOTE(review): presumably clears a Streamlit cache attached to
    # inference_case so it recomputes; the leading underscore on the prompt
    # kwargs presumably excludes them from cache hashing — confirm.
    inference_case.clear()
    preds_3D, preds_3D_ori = inference_case(
        image,
        image_zoom_out,
        text_prompt=text_prompt,
        _point_prompt=point_prompt,
        _box_prompt=box_prompt,
    )
    state.preds_3D = preds_3D
    state.preds_3D_ori = preds_3D_ori
def reflect_box_into_model(box_3d, *, display_size=325.0, xy_size=256.0,
                           z_size=32.0):
    """Map a 3D box from display coordinates into the model's prompt grid.

    Each coordinate is multiplied by its target axis size (``z_size`` for z,
    ``xy_size`` for y and x), divided by ``display_size`` and truncated toward
    zero with ``int``. The defaults reproduce the previously hard-coded
    325 -> 256 (y, x) and 325 -> 32 (z) mapping.

    Args:
        box_3d: Six values ``(z1, y1, x1, z2, y2, x2)``.
        display_size: Extent of the source coordinate range.
            NOTE(review): 325 is presumably the on-screen canvas size — confirm.
        xy_size: Extent of the target y and x axes.
        z_size: Extent of the target z axis.

    Returns:
        A 1-D ``torch.Tensor`` of integers ``[z1, y1, x1, z2, y2, x2]``.

    Raises:
        ValueError: If ``box_3d`` does not hold exactly six values.
    """
    z1, y1, x1, z2, y2, x2 = box_3d
    sizes = (z_size, xy_size, xy_size, z_size, xy_size, xy_size)
    # Multiply before dividing (rather than ``v * (size / display_size)``) so
    # the float result, and therefore the int truncation, is unchanged.
    scaled = [int(v * size / display_size)
              for v, size in zip((z1, y1, x1, z2, y2, x2), sizes)]
    return torch.tensor(np.array(scaled))
def reflect_json_data_to_3D_box(json_data, view):
    """Copy the first canvas rectangle into ``st.session_state.rectangle_3Dbox``.

    Only ``view == 'xy'`` updates the box:
    index 1 <- top, index 2 <- left,
    index 4 <- top + height * scaleY, index 5 <- left + width * scaleX.
    Indices 0 and 3 are never touched, and any other ``view`` leaves the box
    as is. The box is printed in every case.
    """
    if view == 'xy':
        rect = json_data['objects'][0]
        box = st.session_state.rectangle_3Dbox
        box[1] = rect['top']
        box[2] = rect['left']
        # Far edges use the rendered size, i.e. the raw size times the scale.
        box[4] = rect['top'] + rect['height'] * rect['scaleY']
        box[5] = rect['left'] + rect['width'] * rect['scaleX']
    print(st.session_state.rectangle_3Dbox)
def reflect_points_into_model(points, *, display_size=325.0, xy_size=256.0,
                              z_size=32.0):
    """Map click points from display coordinates into the model's prompt grid.

    Each ``(z, y, x)`` point is scaled like ``reflect_box_into_model``: the
    coordinate is multiplied by its target axis size (``z_size`` for z,
    ``xy_size`` for y and x), divided by ``display_size`` and truncated with
    ``int``. The defaults reproduce the previously hard-coded mapping.

    Args:
        points: Iterable of 3-element ``(z, y, x)`` points.
        display_size: Extent of the source coordinate range.
            NOTE(review): 325 is presumably the on-screen canvas size — confirm.
        xy_size: Extent of the target y and x axes.
        z_size: Extent of the target z axis.

    Returns:
        ``(points_prompt, points_label)``, where:

        - ``points_prompt`` is an integer tensor of shape ``(N, 3)`` with
          ``[z, y, x]`` rows.
        - ``points_label`` is a float64 tensor of ``N`` ones.

        An empty ``points`` yields shapes ``(0, 3)`` and ``(0,)`` instead of a
        shapeless float array.

    Raises:
        ValueError: If any point does not hold exactly three values.
    """
    # Multiply before dividing to keep the original truncation bit-identical.
    # dtype=int matches NumPy's default inference for Python ints and keeps
    # an empty input integer-typed; reshape gives it the (0, 3) row shape.
    points_prompt = np.array(
        [[int(z * z_size / display_size),
          int(y * xy_size / display_size),
          int(x * xy_size / display_size)]
         for z, y, x in points],
        dtype=int,
    ).reshape(-1, 3)
    # Every point is given label 1.
    points_label = np.ones(points_prompt.shape[0])
    print(points_prompt, points_label)
    return (torch.tensor(points_prompt), torch.tensor(points_label))
def show_points(points_ax, points_label, ax):
    """Scatter one point on the axes ``ax``.

    ``points_ax[0]`` and ``points_ax[1]`` are passed as the scatter x and y.
    The marker is a large (``s=200``) circle: red when ``points_label == 0``,
    blue otherwise.
    """
    if points_label == 0:
        color = 'red'
    else:
        color = 'blue'
    horizontal, vertical = points_ax[0], points_ax[1]
    ax.scatter(horizontal, vertical, c=color, marker='o', s=200)
def make_fig(image, preds, point_axs=None, current_idx=None, view=None):
    """Render a 2D slice as a contrast-enhanced RGB image with overlays.

    Args:
        image: 2D array. It is multiplied by 255 and cast to uint8, so values
            are presumably in [0, 1] — confirm against callers.
        preds: Optional 2D array. Pixels equal to 1 become a yellow mask that
            is blended over the image with alpha
            ``st.session_state.transparency``.
        point_axs: Optional iterable of ``(z, y, x)`` points.
        current_idx: Index of the displayed slice; only points on it are drawn.
        view: ``'xy'`` draws points with ``z == current_idx`` at ``(x, y)``;
            ``'xz'`` draws points with ``y == current_idx`` at ``(x, z)``.
            Any other value draws no points.

    Returns:
        A ``PIL.Image.Image`` in RGB mode.
    """
    canvas = Image.fromarray((image * 255).astype(np.uint8)).convert("RGB")
    canvas = ImageEnhance.Contrast(canvas).enhance(2.0)

    if preds is not None:
        # Yellow mask: red and green channels at 255 where preds == 1, blue 0.
        mask_arr = np.where(preds == 1, 255, 0).astype(np.uint8)
        channel = Image.fromarray(mask_arr)
        blue = Image.fromarray(np.zeros_like(mask_arr, dtype=np.uint8))
        overlay = Image.merge("RGB", (channel, channel, blue))
        # NOTE(review): Image.blend mixes the whole frame, so pixels outside
        # the mask are also scaled by (1 - alpha), i.e. darkened. Use
        # Image.composite if only the masked region should be tinted.
        canvas = Image.blend(canvas, overlay, alpha=st.session_state.transparency)

    if point_axs is not None:
        pen = ImageDraw.Draw(canvas)
        r = 5
        for z, y, x in point_axs:
            if view == 'xy' and z == current_idx:
                cx, cy = x, y
            elif view == 'xz' and y == current_idx:
                cx, cy = x, z
            else:
                continue
            pen.ellipse((cx - r, cy - r, cx + r, cy + r), fill="blue")
    return canvas