SegVol / utils.py
yuxin
support download
89939b6
raw
history blame
No virus
5.08 kB
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageEnhance, ImageDraw
import torch
import streamlit as st
from model.inference_cpu import inference_case
# Default drawing state holding a single rectangle.
# NOTE(review): the key set and the "4.4.0" version string look like a
# Fabric.js canvas serialization (presumably fed to the drawable canvas
# widget as its initial drawing) -- confirm against the caller.
initial_rectangle = {
    "version": "4.4.0",
    "objects": [
        {
            # Object kind and placement of its origin.
            "type": "rect",
            "version": "4.4.0",
            "originX": "left",
            "originY": "top",
            # Geometry before scaleX/scaleY are applied.
            "left": 50,
            "top": 50,
            "width": 100,
            "height": 100,
            # Fill and outline appearance.
            "fill": "rgba(255, 165, 0, 0.3)",
            "stroke": "#2909F1",
            "strokeWidth": 3,
            "strokeDashArray": None,
            "strokeLineCap": "butt",
            "strokeDashOffset": 0,
            "strokeLineJoin": "miter",
            "strokeUniform": True,
            "strokeMiterLimit": 4,
            # Transform state: unscaled, unrotated, not mirrored.
            "scaleX": 1,
            "scaleY": 1,
            "angle": 0,
            "flipX": False,
            "flipY": False,
            # Rendering options.
            "opacity": 1,
            "shadow": None,
            "visible": True,
            "backgroundColor": "",
            "fillRule": "nonzero",
            "paintFirst": "fill",
            "globalCompositeOperation": "source-over",
            "skewX": 0,
            "skewY": 0,
            # Corner radii (0 means square corners).
            "rx": 0,
            "ry": 0,
        },
    ],
}
def run():
    """Run inference with the prompts currently enabled in the session state.

    Reads the image tensors and prompt settings from ``st.session_state``,
    converts point/box prompts with ``reflect_points_into_model`` and
    ``reflect_box_into_model``, calls ``inference_case`` and stores its two
    results in ``st.session_state.preds_3D`` and
    ``st.session_state.preds_3D_ori``.
    """
    state = st.session_state
    volume = state.data_item["image"].float()
    volume_zoom_out = state.data_item["zoom_out_image"].float()

    # A prompt stays None unless its toggle is on (and, for points, at least
    # one point has been placed).
    text_prompt = state.text_prompt if state.use_text_prompt else None
    point_prompt = (
        reflect_points_into_model(state.points)
        if state.use_point_prompt and len(state.points) > 0
        else None
    )
    box_prompt = (
        reflect_box_into_model(state.rectangle_3Dbox)
        if state.use_box_prompt
        else None
    )

    # Drop any cached result before calling.
    # NOTE(review): presumably inference_case is a Streamlit-cached function
    # and the underscore-prefixed arguments are excluded from its cache key,
    # so a stale result could be returned otherwise -- confirm.
    inference_case.clear()
    preds, preds_ori = inference_case(
        volume,
        volume_zoom_out,
        text_prompt=text_prompt,
        _point_prompt=point_prompt,
        _box_prompt=box_prompt,
    )
    state.preds_3D = preds
    state.preds_3D_ori = preds_ori
def reflect_box_into_model(box_3d, *, canvas_size=325.0, plane_size=256.0,
                           depth_size=32.0):
    """Rescale a 3D box prompt into model coordinates.

    Each coordinate is multiplied by ``target / canvas_size`` and truncated
    with ``int()``; z uses ``depth_size`` as target, y and x use
    ``plane_size``.

    NOTE(review): the defaults presumably describe a 325-unit drawing canvas
    and a 32 x 256 x 256 model input -- confirm against the UI and the model.

    Args:
        box_3d: Sequence of exactly six numbers ``(z1, y1, x1, z2, y2, x2)``.
        canvas_size: Extent of the source coordinate system (all axes).
        plane_size: Target extent for the y and x axes.
        depth_size: Target extent for the z axis.

    Returns:
        A 1-D integer ``torch.Tensor`` ``[z1, y1, x1, z2, y2, x2]``.

    Raises:
        ValueError: If ``box_3d`` does not unpack into six values.
    """
    # Explicit unpacking keeps the length check of the six-value contract.
    z1, y1, x1, z2, y2, x2 = box_3d
    corners = (
        (z1, depth_size), (y1, plane_size), (x1, plane_size),
        (z2, depth_size), (y2, plane_size), (x2, plane_size),
    )
    # Multiply first, then divide, so results match the original arithmetic.
    scaled = [int(value * target / canvas_size) for value, target in corners]
    return torch.tensor(np.array(scaled))
def reflect_json_data_to_3D_box(json_data, view):
    """Copy the first drawn rectangle into ``st.session_state.rectangle_3Dbox``.

    Only ``view == 'xy'`` changes anything: slots 1 and 2 receive the
    rectangle's ``top`` and ``left``, slots 4 and 5 receive ``top`` and
    ``left`` plus the scaled ``height`` and ``width``. Slots 0 and 3 are
    never written here. The current box is printed on every call.

    Args:
        json_data: Mapping with an ``'objects'`` list whose first entry has
            ``top``, ``left``, ``width``, ``height``, ``scaleX`` and
            ``scaleY`` keys.
        view: Name of the view the rectangle was drawn in.
    """
    if view == 'xy':
        rect = json_data['objects'][0]
        box = st.session_state.rectangle_3Dbox
        # Upper-left corner.
        box[1] = rect['top']
        box[2] = rect['left']
        # Lower-right corner: origin plus the scaled extent.
        box[4] = rect['top'] + rect['height'] * rect['scaleY']
        box[5] = rect['left'] + rect['width'] * rect['scaleX']
    print(st.session_state.rectangle_3Dbox)
def reflect_points_into_model(points, *, canvas_size=325.0, plane_size=256.0,
                              depth_size=32.0):
    """Rescale point prompts into model coordinates and label them.

    Each coordinate is multiplied by ``target / canvas_size`` and truncated
    with ``int()``; z uses ``depth_size`` as target, y and x use
    ``plane_size``. Every point receives the label ``1``.

    NOTE(review): the defaults presumably describe a 325-unit drawing canvas
    and a 32 x 256 x 256 model input -- confirm against the UI and the model.

    Args:
        points: Iterable of ``(z, y, x)`` triples; may be empty.
        canvas_size: Extent of the source coordinate system (all axes).
        plane_size: Target extent for the y and x axes.
        depth_size: Target extent for the z axis.

    Returns:
        ``(coords, labels)``: an integer tensor of shape ``(N, 3)`` holding
        ``[z, y, x]`` rows and a float tensor of ``N`` ones.
    """
    points_prompt_list = [
        [
            int(z * depth_size / canvas_size),
            int(y * plane_size / canvas_size),
            int(x * plane_size / canvas_size),
        ]
        for z, y, x in points
    ]
    # Pin the dtype and shape so an empty list yields an integer (0, 3) array
    # instead of numpy's default float64 array of shape (0,).
    points_prompt = np.array(points_prompt_list, dtype=int).reshape(-1, 3)
    points_label = np.ones(points_prompt.shape[0])
    print(points_prompt, points_label)
    return (torch.tensor(points_prompt), torch.tensor(points_label))
def show_points(points_ax, points_label, ax):
    """Scatter one marker on ``ax``, coloured by its label.

    Args:
        points_ax: Indexable whose items 0 and 1 are passed to ``ax.scatter``
            as its first two positional arguments.
        points_label: Label of the point; ``0`` is drawn red, anything else
            blue.
        ax: Object exposing a ``scatter`` method (presumably a matplotlib
            Axes -- confirm against the caller).
    """
    if points_label == 0:
        marker_color = 'red'
    else:
        marker_color = 'blue'
    first, second = points_ax[0], points_ax[1]
    ax.scatter(first, second, c=marker_color, marker='o', s=200)
def make_fig(image, preds, point_axs=None, current_idx=None, view=None):
    """Render one slice as an RGB picture with optional mask and point markers.

    Args:
        image: 2-D numpy array; it is multiplied by 255 and cast to uint8
            (assumes values lie in [0, 1] -- TODO confirm).
        preds: Optional array; pixels equal to 1 are painted yellow and the
            result is blended with the slice using
            ``st.session_state.transparency`` as alpha. ``None`` skips the
            overlay.
        point_axs: Optional iterable of ``(z, y, x)`` points.
        current_idx: Index of the displayed slice; only points lying on it
            are drawn.
        view: ``'xy'`` draws points with ``z == current_idx`` at ``(x, y)``;
            ``'xz'`` draws points with ``y == current_idx`` at ``(x, z)``.
            Any other value draws no points.

    Returns:
        The resulting ``PIL.Image.Image`` in RGB mode.
    """
    # Greyscale slice -> RGB picture with doubled contrast.
    as_bytes = (image * 255).astype(np.uint8)
    figure = Image.fromarray(as_bytes).convert("RGB")
    figure = ImageEnhance.Contrast(figure).enhance(2.0)

    if preds is not None:
        # Yellow = full red + full green, no blue, wherever preds == 1.
        on_off = np.where(preds == 1, 255, 0).astype(np.uint8)
        red = Image.fromarray(on_off)
        green = Image.fromarray(on_off)
        blue = Image.fromarray(np.zeros_like(on_off, dtype=np.uint8))
        overlay = Image.merge("RGB", (red, green, blue))
        # Blend the overlay into the slice.
        figure = Image.blend(
            figure.convert("RGB"), overlay, alpha=st.session_state.transparency
        )

    if point_axs is not None:
        pen = ImageDraw.Draw(figure)
        radius = 5
        for z, y, x in point_axs:
            # Pick the in-plane coordinates for the active view; skip points
            # that do not lie on the displayed slice.
            if view == 'xy' and z == current_idx:
                col, row = x, y
            elif view == 'xz' and y == current_idx:
                col, row = x, z
            else:
                continue
            pen.ellipse(
                (col - radius, row - radius, col + radius, row + radius),
                fill="blue",
            )

    return figure