DveloperY0115's picture
init repo
801501a
import torch
import vtk
from utils import files_utils
from custom_types import *
import functools
import matplotlib.pyplot as plt
bg_source_color = (152, 181, 234)
bg_target_color = (250, 200, 152)
button_color = (255, 0, 255)
bg_menu_color = (214, 139, 202)
bg_stage_color = (255, 180, 110)
default_colors = [(82, 108, 255), (160, 82, 255), (255, 43, 43), (255, 246, 79),
(153, 227, 107), (58, 186, 92), (8, 243, 255), (240, 136, 0)]
class SmoothingMethod(Enum):
Laplace = "laplace"
Taubin = "taubin"
class EditType(enum.Enum):
Pondering = 'pondering'
Translating = 'translating'
Rotating = 'rotating'
Scaling = 'scaling'
Marking = 'marking'
class EditDirection(enum.Enum):
X_Axis = 'axis_x'
Y_Axis = 'axis_y'
Z_Axis = 'axis_z'
palette = (
(63, 72, 204),
(51, 213, 73),
(213, 51, 159),
(153, 227, 107),
(246, 162, 81)
)
# palette = [(.6196, 0.0039, 0.2588),
# (.6873, 0.0790, 0.2748),
# (.7549, 0.1540, 0.2908),
# (.8226, 0.2291, 0.3068),
# (.8710, 0.2973, 0.2960),
# (.9092, 0.3552, 0.2812),
# (.9473, 0.4130, 0.2664),
# (.9652, 0.4874, 0.2904),
# (.9776, 0.5774, 0.3319),
# (.9887, 0.6574, 0.3689),
# (.9930, 0.7246, 0.4159),
# (.9942, 0.7862, 0.4676),
# (.9956, 0.8554, 0.5257),
# (.9968, 0.9023, 0.5851),
# (.9981, 0.9404, 0.6491),
# (.9993, 0.9785, 0.7130),
# (.9827, 0.9931, 0.7220),
# (.9519, 0.9808, 0.6740),
# (.9212, 0.9685, 0.6261),
# (.8747, 0.9497, 0.6016),
# (.7931, 0.9165, 0.6182),
# (.7205, 0.8870, 0.6330),
# (.6441, 0.8563, 0.6435),
# (.5592, 0.8231, 0.6448),
# (.4637, 0.7857, 0.6461),
# (.3840, 0.7429, 0.6544),
# (.3200, 0.6716, 0.6840),
# (.2561, 0.6002, 0.7135),
# (.2062, 0.5202, 0.7349),
# (.2604, 0.4501, 0.7017),
# (.3145, 0.3799, 0.6685),
# (.3686, 0.3098, 0.6353)]
RGB_COLOR = Union[Tuple[int, int, int], List[int]]
RGB_FLOAT_COLOR = Union[Tuple[float, float, float], List[float]]
RGBA_COLOR = Union[Tuple[int, int, int, int], List[int]]
RGBA_FLOAT_COLOR = Union[Tuple[float, float, float, float], List[float]]
def channel_to_float(*channel: int):
if type(channel[0]) is float and 0 <= channel[0] <= 1:
return channel
return [c / 255. for c in channel]
def rgb_to_float(*colors: RGB_COLOR) -> Union[RGB_FLOAT_COLOR, List[RGB_FLOAT_COLOR]]:
float_colors = [channel_to_float(*c) for c in colors]
if len(float_colors) == 1:
return float_colors[0]
return float_colors
def rgb_to_rgba_float(color: RGB_COLOR, alpha: float) -> RGBA_FLOAT_COLOR:
color = list(rgb_to_float(color)) + [alpha]
return color
class Buttons(enum.Enum):
translate = 'T'
rotate = 'R'
stretch = 'S'
reset = 'reset'
update = 'hq'
symmetric = 'symmetric'
empty = -1
class ViewStyle:
def __init__(self, base_color: RGB_COLOR, included_color: RGB_COLOR, selected_color: RGB_COLOR,
opacity: float):
self.base_color = rgb_to_float(base_color)
self.included_color = rgb_to_float(included_color)
self.stroke_color = list(selected_color) + [200]
self.selected_color = rgb_to_float(selected_color)
self.opacity = opacity
class Transition:
def __init__(self, transition_origin: ARRAY, transition_type: EditType):
self.transition_origin: ARRAY = transition_origin
self.transition_type: ARRAY = transition_type
self.translation: ARRAY = np.zeros(3)
self.rotation: ARRAY = np.eye(3)
@functools.lru_cache(10)
def get_rotation_matrix(theta: float, axis: float) -> ARRAY:
rotate_mat = np.eye(3)
rotate_mat[axis, axis] = 1
cos_theta, sin_theta = np.cos(theta), np.sin(theta)
rotate_mat[(axis + 1) % 3, (axis + 1) % 3] = cos_theta
rotate_mat[(axis + 2) % 3, (axis + 2) % 3] = cos_theta
rotate_mat[(axis + 1) % 3, (axis + 2) % 3] = sin_theta
rotate_mat[(axis + 2) % 3, (axis + 1) % 3] = -sin_theta
return rotate_mat
def load_vtk(path: str, vtk_reader):
vtk_reader.SetFileName(path)
vtk_reader.Update()
source = vtk_reader.GetOutput()
return source
def save_vtk(data, path: str, vtk_writer):
vtk_writer.SetFileName(path)
vtk_writer.SetInputData(data)
vtk_writer.Update()
vtk_writer.Write()
def load_vtk_obj(path: str):
path = files_utils.add_suffix(path, ".obj")
return load_vtk(path, vtk.vtkOBJReader())
def save_vtk_image(data, path: str):
path = files_utils.add_suffix(path, ".vtk")
files_utils.init_folders(path)
save_vtk(data, path, vtk.vtkXMLImageDataWriter())
def load_vtk_image(path: str) -> vtk.vtkImageData:
path = files_utils.add_suffix(path, ".vtk")
return load_vtk(path, vtk.vtkXMLImageDataReader())
def set_default_properties(actor: vtk.vtkActor, color: Tuple[float, float, float]):
properties = actor.GetProperty()
properties.SetPointSize(10)
properties.SetDiffuseColor(.6, .6, .6)
properties.SetAmbient(.2)
properties.SetDiffuse(.8)
properties.SetSpecular(.5)
properties.SetSpecularColor(.2, .2, .2)
properties.SetSpecularPower(30.0)
properties.SetColor(*color)
return actor
def wrap_mesh(source, color):
mapper = vtk.vtkPolyDataMapper()
mapper.SetInputData(source)
actor = vtk.vtkActor()
actor.SetMapper(mapper)
actor = set_default_properties(actor, color)
return actor, mapper
def create_vtk_image(path: str) -> vtk.vtkImageData:
root, name, _ = files_utils.split_path(path)
cache_image_path = f"{root}/cache/{name}.vtk"
if not files_utils.is_file(cache_image_path):
np_image = files_utils.load_image(path, 'RGBA')
image = vtk.vtkImageData()
image.SetDimensions(np_image.shape[1], np_image.shape[0], 1)
image.AllocateScalars(vtk.VTK_UNSIGNED_CHAR, np_image.shape[2])
dims = image.GetDimensions()
for y in range(dims[1]):
for x in range(dims[0]):
pixel = np_image[dims[1] - 1 - y, x]
for i in range(np_image.shape[2]):
image.SetScalarComponentFromDouble(x, y, 0, i, pixel[i])
# points = image.GetPointData().GetArray(0)
save_vtk_image(image, cache_image_path)
else:
image = load_vtk_image(cache_image_path)
return image
class ImageButton(vtk.vtkButtonWidget):
def process_state_change_event(self, obj, event):
print(f"end event {self.button_representation.GetState()}")
def set_size(self, window_size: Tuple[float, float]):
w, h = window_size[0] * self.full_size[0], window_size[1] * self.full_size[1]
pos_left, pos_top = int(w * self.position[0]), int(h * self.position[1])
position_coords = [pos_left,
pos_left + int(w * self.size[0]),
pos_top - int(h *self.size[1]),
pos_top,
0, 0]
self.button_representation.PlaceWidget(position_coords)
def resize_event(self, obj, event):
self.set_size(obj.GetSize())
def __init__(self, images_paths: List[str], interactor, render, size: Union[float, Tuple[float, float]],
position: Tuple[float, float], on_click: Optional[Callable[[Any, Any], None]] = None,
full_size: Tuple[float, float] = (1., 1.)):
super(ImageButton, self).__init__()
self.SetCurrentRenderer(render)
if type(size) is float:
size = (size, size)
self.full_size = full_size
render_window: vtk.vtkRenderWindow = interactor.GetRenderWindow()
images = map(lambda x: create_vtk_image(x), images_paths)
self.button_representation = vtk.vtkTexturedButtonRepresentation2D()
self.button_representation.SetNumberOfStates(len(images_paths))
self.button_representation.GetProperty().SetColor(1, 1, 1)
for i, image in enumerate(images):
self.button_representation.SetButtonTexture(i, image)
self.SetInteractor(interactor)
self.SetRepresentation(self.button_representation)
self.size = size
self.position = position
self.button_representation.SetPlaceFactor(1)
self.set_size(render_window.GetSize())
render_window.AddObserver(vtk.vtkCommand.WindowResizeEvent, self.resize_event)
if on_click is not None:
self.AddObserver(vtk.vtkCommand.StateChangedEvent, on_click)
self.On()
selection_prop = self.button_representation.GetSelectingProperty()
selection_prop.SetLineWidth(0.)
selection_prop.SetColor(1., 1., 1.)
def make_slider(iren, observer):
to_show = False
if to_show:
ren_left = vtk.vtkRenderer()
ren_left.SetBackground(*rgb_to_float((250, 255, 255)))
ren_window = vtk.vtkRenderWindow()
ren_window.AddRenderer(ren_left)
iren = vtk.vtkRenderWindowInteractor()
iren.SetRenderWindow(ren_window)
ren_window.Render()
slider_repres = vtk.vtkSliderRepresentation2D()
slider_repres.SetMinimumValue(0)
slider_repres.SetMaximumValue(100.)
# slider_repres.SetTitleText('Mesh\nOpacity')
slider_repres.SetValue(30.)
slider_repres.GetSliderProperty().SetColor(*rgb_to_float(bg_target_color))
slider_repres.ShowSliderLabelOff()
# slider_repres.GetLabelProperty().SetColor(1., 0., 0.)
slider_repres.GetCapProperty().SetColor(*rgb_to_float(bg_menu_color))
slider_repres.GetSelectedProperty().SetColor(1., 0., 0)
slider_repres.GetTubeProperty().SetColor(*rgb_to_float(bg_source_color))
slider_repres.GetPoint1Coordinate().SetCoordinateSystemToNormalizedDisplay()
slider_repres.GetPoint1Coordinate().SetValue(0.01, 0.1)
slider_repres.GetPoint2Coordinate().SetCoordinateSystemToNormalizedDisplay()
slider_repres.GetPoint2Coordinate().SetValue(0.23, 0.1)
slider_repres.SetSliderLength(0.01)
slider_repres.SetSliderWidth(0.01)
slider_repres.SetEndCapLength(0.01)
slider_repres.SetEndCapWidth(0.01)
slider_repres.SetTubeWidth(0.01)
slider_repres.SetLabelFormat('%f')
slider_widget = vtk.vtkSliderWidget()
slider_widget.SetInteractor(iren)
slider_widget.SetRepresentation(slider_repres)
slider_widget.KeyPressActivationOff()
slider_widget.SetAnimationModeToAnimate()
slider_widget.SetEnabled(True)
slider_widget.AddObserver('InteractionEvent', observer)
slider_widget.EnabledOn()
if to_show:
iren.Initialize()
ren_window.Render()
iren.Start()
del iren
del ren_window
return slider_widget, slider_repres
class CanvasRender(vtk.vtkRenderer):
@property
def origin_x(self):
return self.viewport_ren[0]
@property
def origin_y(self):
return self.viewport_ren[1]
@property
def width(self):
return self.viewport_ren[2] - self.viewport_ren[0]
@property
def height(self):
return self.viewport_ren[3] - self.viewport_ren[1]
def translate_point(self, pt: Tuple[int, int]) -> Tuple[int, int]:
return pt[0] - self.origin_x, pt[1] - self.origin_y
def get_mid_points(self, pt: Tuple[int, int]) -> List[List[int]]:
if self.last_point is None:
return []
pt_a, pt_b = torch.tensor(pt, dtype=torch.float32), torch.tensor(self.last_point, dtype=torch.float32)
delta = pt_b - pt_a
num_mids = max(int(delta.norm(2, 0).item() / 10), 2)
# num_mids = 4
mid_points = pt_a[None, :] + torch.linspace(0, 1, num_mids)[:, None] * delta[None, :]
mid_points[:, 0] += self.origin_x
mid_points[:, 1] += self.origin_y
return mid_points[:-1].long().tolist()
def draw(self, pt: Tuple[int, int], stroke_width: float = 5.) -> List[List[int]]:
pt = self.translate_point(pt)
if self.last_point is not None:
self.canvas.FillTube(*self.last_point, *pt, stroke_width)
self.canvas.Update()
mid_points = self.get_mid_points(pt)
self.last_point = pt
return mid_points
def clear(self):
self.last_point = None
self.canvas.SetDrawColor(0, 0, 0, 0)
self.canvas.FillBox(0, self.width, 0, self.height)
self.canvas.SetDrawColor(*self.stroke_color)
self.canvas.Update()
def resize_event_(self, obj):
self.viewport_ren = self.set_int_viewport(obj.GetSize())
self.canvas.SetExtent(0, self.width, 0, self.height, 0, 0)
self.canvas.Update()
self.clear()
self.set_camera()
def resize_event(self, obj, event):
self.resize_event_(obj)
def set_camera(self):
origin = self.image_data.GetOrigin()
spacing = self.image_data.GetSpacing()
extent = self.image_data.GetExtent()
camera = self.canvas_render.GetActiveCamera()
camera.ParallelProjectionOn()
xc = origin[0] + 0.5 * (extent[0] + extent[1]) * spacing[0]
yc = origin[1] + 0.5 * (extent[2] + extent[3]) * spacing[1]
# xd = (extent[1] - extent[0] + 1) * spacing[0]
yd = (extent[3] - extent[2] + 1) * spacing[1]
d = camera.GetDistance()
camera.SetParallelScale(0.5 * yd)
camera.SetFocalPoint(xc, yc, 0.0)
camera.SetPosition(xc, yc, d)
def set_int_viewport(self, win_size) -> Tuple[int, int, int, int]:
w, h = win_size
return int(self.viewport[0] * w), int(self.viewport[1] * h), int(self.viewport[2] * w), int(self.viewport[3] * h)
def init_canvas(self):
self.canvas.SetExtent(0, self.width, 0, self.height, 0, 0)
self.canvas.PropagateUpdateExtent()
self.canvas.UpdateExtent((0, self.width, 0, self.height, 0, 0))
self.canvas.SetScalarTypeToUnsignedChar()
self.canvas.SetNumberOfScalarComponents(4)
self.set_brush(True)
image_data = self.canvas.GetOutput()
image_actor = vtk.vtkImageActor()
image_actor.SetInputData(image_data)
self.canvas_render.AddActor(image_actor)
return image_data
def set_brush(self, is_draw: bool):
self.is_draw = is_draw
self.stroke_color = self.base_stroke_color if is_draw else (255, 255, 255, 200)
# (*bg_menu_color, 150)
self.canvas.SetDrawColor(*self.stroke_color)
self.canvas.Update()
def change_brush(self, stroke_color):
self.base_stroke_color = stroke_color
self.set_brush(self.is_draw)
def __init__(self, viewport: Tuple[float, float, float, float], render_window: vtk.vtkRenderWindow,
bg_color: RGB_COLOR, stroke_color: Optional[RGBA_COLOR] = None):
super(CanvasRender, self).__init__()
self.SetViewport(*viewport)
self.viewport = viewport
self.canvas_render = vtk.vtkRenderer()
self.canvas_render.SetViewport(*viewport)
if stroke_color is None:
stroke_color = vtk.vtkNamedColors().GetColor4ub('LightCoral')
stroke_color = stroke_color.GetRed(), stroke_color.GetGreen(), stroke_color.GetBlue(), 200
self.base_stroke_color = self.stroke_color = stroke_color
# self.SetBackground(*bg_color)
self.is_draw = True
self.canvas_render.InteractiveOff()
self.viewport_ren = self.set_int_viewport(render_window.GetSize())
self.canvas = vtk.vtkImageCanvasSource2D()
self.image_data = self.init_canvas()
render_window.AddObserver(vtk.vtkCommand.WindowResizeEvent, self.resize_event)
self.last_point: Optional[Tuple[int, int]] = None
self.SetLayer(0)
self.canvas_render.SetLayer(1)
self.SetBackground(*rgb_to_float(bg_color))
render_window.AddRenderer(self)
render_window.AddRenderer(self.canvas_render)
self.set_camera()
def init_palettes(cmap='Spectral'):
colors = {}
color_map = plt.cm.get_cmap(cmap)
def get_palette(num_colors: int) -> T:
nonlocal colors, color_map
if num_colors == 1:
colors[num_colors] = torch.tensor([.45])
if num_colors not in colors:
colors[num_colors] = torch.tensor([color_map(float(idx) / (num_colors - 1)) for idx in range(num_colors)])
return colors[num_colors]
return get_palette
def get_view_styles(num_styles: int, is_main: bool) -> List[ViewStyle]:
global palette
base_color = (255, 255, 255)
opacity = 1
colors = init_palettes()(max(num_styles, 100))
colors = colors[torch.rand(100).argsort()][:num_styles].tolist()
colors = map(lambda x: list(map(lambda c: int(255 * c), x[:3])), colors)
# if len(palette_) < num_styles:
# palette_ = palette_ + [tuple(item) for item in torch.randint(255, size=(num_styles - len(palette_), 3)).tolist()]
view_styles = []
for i, color in enumerate(colors):
if is_main:
view_styles.append(ViewStyle(base_color, base_color, color, opacity))
else:
view_styles.append(ViewStyle(base_color, color, color, opacity))
return view_styles