Spaces:
Sleeping
Sleeping
import torch | |
import vtk | |
from utils import files_utils | |
from custom_types import * | |
import functools | |
import matplotlib.pyplot as plt | |
# UI colour constants as (R, G, B) tuples with integer channels in 0..255.
bg_source_color = (152, 181, 234)
bg_target_color = (250, 200, 152)
button_color = (255, 0, 255)
bg_menu_color = (214, 139, 202)
bg_stage_color = (255, 180, 110)

# List of eight (R, G, B) colours, integer channels in 0..255.
default_colors = [
    (82, 108, 255),
    (160, 82, 255),
    (255, 43, 43),
    (255, 246, 79),
    (153, 227, 107),
    (58, 186, 92),
    (8, 243, 255),
    (240, 136, 0),
]
class SmoothingMethod(Enum):
    """Available smoothing methods; each member's value is its string identifier."""

    Laplace = "laplace"
    Taubin = "taubin"
class EditType(enum.Enum):
    """Kinds of edit interaction, keyed by a lowercase string identifier."""

    Pondering = 'pondering'
    Translating = 'translating'
    Rotating = 'rotating'
    Scaling = 'scaling'
    Marking = 'marking'
class EditDirection(enum.Enum):
    """Coordinate axis an edit is constrained to."""

    X_Axis = 'axis_x'
    Y_Axis = 'axis_y'
    Z_Axis = 'axis_z'
# Five (R, G, B) colours with integer channels in 0..255, kept as an immutable tuple.
palette = (
    (63, 72, 204),
    (51, 213, 73),
    (213, 51, 159),
    (153, 227, 107),
    (246, 162, 81),
)
# palette = [(.6196, 0.0039, 0.2588), | |
# (.6873, 0.0790, 0.2748), | |
# (.7549, 0.1540, 0.2908), | |
# (.8226, 0.2291, 0.3068), | |
# (.8710, 0.2973, 0.2960), | |
# (.9092, 0.3552, 0.2812), | |
# (.9473, 0.4130, 0.2664), | |
# (.9652, 0.4874, 0.2904), | |
# (.9776, 0.5774, 0.3319), | |
# (.9887, 0.6574, 0.3689), | |
# (.9930, 0.7246, 0.4159), | |
# (.9942, 0.7862, 0.4676), | |
# (.9956, 0.8554, 0.5257), | |
# (.9968, 0.9023, 0.5851), | |
# (.9981, 0.9404, 0.6491), | |
# (.9993, 0.9785, 0.7130), | |
# (.9827, 0.9931, 0.7220), | |
# (.9519, 0.9808, 0.6740), | |
# (.9212, 0.9685, 0.6261), | |
# (.8747, 0.9497, 0.6016), | |
# (.7931, 0.9165, 0.6182), | |
# (.7205, 0.8870, 0.6330), | |
# (.6441, 0.8563, 0.6435), | |
# (.5592, 0.8231, 0.6448), | |
# (.4637, 0.7857, 0.6461), | |
# (.3840, 0.7429, 0.6544), | |
# (.3200, 0.6716, 0.6840), | |
# (.2561, 0.6002, 0.7135), | |
# (.2062, 0.5202, 0.7349), | |
# (.2604, 0.4501, 0.7017), | |
# (.3145, 0.3799, 0.6685), | |
# (.3686, 0.3098, 0.6353)] | |
# Colour type aliases. A colour is either a fixed-length tuple or a list.
# The *_FLOAT_* variants hold float channels; the others hold int channels.
RGB_COLOR = Union[Tuple[int, int, int], List[int]]
RGB_FLOAT_COLOR = Union[Tuple[float, float, float], List[float]]
# Same as above with a fourth (alpha) channel.
RGBA_COLOR = Union[Tuple[int, int, int, int], List[int]]
RGBA_FLOAT_COLOR = Union[Tuple[float, float, float, float], List[float]]
def channel_to_float(*channel: Union[int, float]):
    """Normalise colour channels to floats in [0, 1].

    If the first channel is already a float in [0, 1], the channels are taken
    to be normalised and the input tuple is returned unchanged. Otherwise
    every channel is divided by 255 and a new list is returned.

    ``isinstance`` is used instead of ``type(...) is float`` so that float
    subclasses such as ``numpy.float64`` are recognised as already
    normalised, rather than being divided by 255 a second time.

    Returns:
        The input tuple (already normalised), a list of floats, or ``[]``
        when called with no channels.
    """
    if not channel:
        return []
    first = channel[0]
    if isinstance(first, float) and 0 <= first <= 1:
        return channel
    return [c / 255. for c in channel]
def rgb_to_float(*colors: RGB_COLOR) -> Union[RGB_FLOAT_COLOR, List[RGB_FLOAT_COLOR]]:
    """Normalise one or more colours with ``channel_to_float``.

    With exactly one colour, that single converted colour is returned.
    Otherwise a list with one converted colour per argument is returned
    (an empty list when called with no colours).
    """
    converted = [channel_to_float(*rgb) for rgb in colors]
    return converted[0] if len(converted) == 1 else converted
def rgb_to_rgba_float(color: RGB_COLOR, alpha: float) -> RGBA_FLOAT_COLOR:
    """Return ``color`` normalised by ``rgb_to_float`` as a list with ``alpha`` appended."""
    return [*rgb_to_float(color), alpha]
class Buttons(enum.Enum):
    """Button identifiers.

    Values are short string labels, except ``empty``, whose value is the
    integer -1.
    """

    translate = 'T'
    rotate = 'R'
    stretch = 'S'
    reset = 'reset'
    update = 'hq'
    symmetric = 'symmetric'
    empty = -1
class ViewStyle:
    """Colour and opacity settings for one view.

    ``base_color``, ``included_color`` and ``selected_color`` are passed
    through ``rgb_to_float``. ``stroke_color`` keeps the raw
    ``selected_color`` channels as given and appends an alpha of 200.
    """

    def __init__(self, base_color: RGB_COLOR, included_color: RGB_COLOR, selected_color: RGB_COLOR,
                 opacity: float):
        self.base_color = rgb_to_float(base_color)
        self.included_color = rgb_to_float(included_color)
        # Un-normalised channels plus alpha 200 (0-255 scale).
        self.stroke_color = [*selected_color, 200]
        self.selected_color = rgb_to_float(selected_color)
        self.opacity = opacity
class Transition:
    """State for one in-progress edit.

    Stores the edit's origin and type, plus a translation vector and a 3x3
    rotation matrix. These start as the identity transform (zero
    translation, identity rotation).
    """

    def __init__(self, transition_origin: ARRAY, transition_type: EditType):
        self.transition_origin: ARRAY = transition_origin
        # Stores the EditType passed in, not an array.
        self.transition_type: EditType = transition_type
        self.translation: ARRAY = np.zeros(3)
        self.rotation: ARRAY = np.eye(3)
def get_rotation_matrix(theta: float, axis: int) -> ARRAY:
    """Return the 3x3 matrix rotating by ``theta`` about one coordinate axis.

    Args:
        theta: angle in radians (passed to ``np.cos`` / ``np.sin``).
        axis: integer index of the rotation axis, 0, 1 or 2. Negative indices
            -3..-1 alias 0..2 as in NumPy indexing. It is used as an array
            index, so it must be an integer.

    Returns:
        The identity matrix with the block of the two other axes
        ``i = (axis + 1) % 3`` and ``j = (axis + 2) % 3`` replaced by
        ``R[i, i] = R[j, j] = cos(theta)``, ``R[i, j] = sin(theta)`` and
        ``R[j, i] = -sin(theta)``. With this sign layout, ``R @ v`` rotates a
        column vector by ``-theta``; equivalently, ``v @ R`` rotates a row
        vector by ``+theta``.

    Raises:
        IndexError: if ``axis`` is outside [-3, 3) or is not an integer.
    """
    if not -3 <= axis < 3:
        raise IndexError(f"axis must be in [-3, 3), got {axis}")
    i, j = (axis + 1) % 3, (axis + 2) % 3
    cos_theta, sin_theta = np.cos(theta), np.sin(theta)
    rotate_mat = np.eye(3)  # the axis' own diagonal entry stays 1
    rotate_mat[i, i] = cos_theta
    rotate_mat[j, j] = cos_theta
    rotate_mat[i, j] = sin_theta
    rotate_mat[j, i] = -sin_theta
    return rotate_mat
def load_vtk(path: str, vtk_reader):
    """Read ``path`` with ``vtk_reader`` and return the reader's output object."""
    vtk_reader.SetFileName(path)
    # Execute the reader so GetOutput() holds the loaded data.
    vtk_reader.Update()
    return vtk_reader.GetOutput()
def save_vtk(data, path: str, vtk_writer):
    """Write ``data`` to ``path`` using ``vtk_writer``.

    Only ``Write()`` is called. For VTK writers, ``Write()`` itself brings
    the pipeline up to date and writes the file, and a separate
    ``Update()`` also executes the writer. Calling both writes the file
    twice.

    Returns:
        The value returned by ``vtk_writer.Write()``. VTK writers return 1 on
        success and 0 on failure; callers that ignore the result are
        unaffected.
    """
    vtk_writer.SetFileName(path)
    vtk_writer.SetInputData(data)
    return vtk_writer.Write()
def load_vtk_obj(path: str):
    """Load an OBJ mesh with vtkOBJReader.

    ``path`` is first passed through ``files_utils.add_suffix(path, ".obj")``.
    """
    return load_vtk(files_utils.add_suffix(path, ".obj"), vtk.vtkOBJReader())
def save_vtk_image(data, path: str):
    """Save image data with vtkXMLImageDataWriter to ``path`` (".vtk" suffix applied).

    NOTE(review): vtkXMLImageDataWriter writes VTK's XML image format even
    though the file gets a ".vtk" suffix. ``load_vtk_image`` reads it back
    with the matching XML reader, so the pair is consistent.
    """
    target = files_utils.add_suffix(path, ".vtk")
    # Presumably creates the parent folders of target; see files_utils.
    files_utils.init_folders(target)
    save_vtk(data, target, vtk.vtkXMLImageDataWriter())
def load_vtk_image(path: str) -> vtk.vtkImageData:
    """Read image data written by ``save_vtk_image`` (".vtk" suffix applied)."""
    return load_vtk(files_utils.add_suffix(path, ".vtk"), vtk.vtkXMLImageDataReader())
def set_default_properties(actor: vtk.vtkActor, color: Tuple[float, float, float]):
    """Apply the default material settings to ``actor``'s property and return ``actor``.

    ``color`` is unpacked into ``SetColor`` as three channels; float values
    in [0, 1] are presumably expected.

    NOTE(review): vtkProperty.SetColor also sets the ambient, diffuse and
    specular colours, so the grey SetDiffuseColor/SetSpecularColor values
    applied before it are overridden. Confirm that is intended.
    """
    prop = actor.GetProperty()
    prop.SetPointSize(10)
    prop.SetDiffuseColor(.6, .6, .6)
    # Lighting coefficients.
    prop.SetAmbient(.2)
    prop.SetDiffuse(.8)
    prop.SetSpecular(.5)
    prop.SetSpecularColor(.2, .2, .2)
    prop.SetSpecularPower(30.0)
    # Applied last, so it wins over the colours set above.
    prop.SetColor(*color)
    return actor
def wrap_mesh(source, color):
    """Create an actor rendering the poly data ``source`` with the default material.

    Returns:
        ``(actor, mapper)``. The actor has ``set_default_properties(actor, color)``
        applied, and the mapper is its vtkPolyDataMapper.
    """
    poly_mapper = vtk.vtkPolyDataMapper()
    poly_mapper.SetInputData(source)
    mesh_actor = vtk.vtkActor()
    mesh_actor.SetMapper(poly_mapper)
    return set_default_properties(mesh_actor, color), poly_mapper
def create_vtk_image(path: str) -> vtk.vtkImageData:
    """Load the image at ``path`` as vtkImageData, caching the conversion on disk.

    The converted image is stored as ``<root>/cache/<name>.vtk``, where
    ``root`` and ``name`` come from ``files_utils.split_path(path)``. Later
    calls load that cached file instead of converting again. On a cache
    miss the freshly built in-memory image is returned.

    NOTE(review): the cache is never invalidated. If the source image
    changes, the stale cached file keeps being used until it is deleted.
    """
    root, name, _ = files_utils.split_path(path)
    cache_image_path = f"{root}/cache/{name}.vtk"
    if files_utils.is_file(cache_image_path):
        return load_vtk_image(cache_image_path)
    image = _image_array_to_vtk(files_utils.load_image(path, 'RGBA'))
    save_vtk_image(image, cache_image_path)
    return image


def _image_array_to_vtk(np_image) -> vtk.vtkImageData:
    """Copy an image array of shape (rows, cols, channels) into a new single-slice vtkImageData.

    Point (x, y) of the result receives ``np_image[rows - 1 - y, x]``: the
    rows are flipped so that array row 0 becomes the top row of the VTK
    image. Values are cast to unsigned char. This is done with one
    vectorised copy instead of a per-pixel Python loop.
    """
    import numpy as np
    from vtk.util import numpy_support

    height, width, num_channels = np_image.shape
    image = vtk.vtkImageData()
    image.SetDimensions(width, height, 1)
    # VTK orders point scalars x-fastest, starting at y = 0. After the
    # vertical flip, a row-major flatten produces exactly that order.
    flipped = np.ascontiguousarray(np_image[::-1], dtype=np.uint8)
    scalars = numpy_support.numpy_to_vtk(flipped.reshape(-1, num_channels), deep=True)
    # Same array name that vtkImageData.AllocateScalars uses.
    scalars.SetName("ImageScalars")
    image.GetPointData().SetScalars(scalars)
    return image
class ImageButton(vtk.vtkButtonWidget):
    """A VTK button widget whose states are drawn with image textures.

    Each path in ``images_paths`` becomes one button state, textured with
    ``create_vtk_image(path)``. ``position`` and ``size`` are fractions of the
    render-window size scaled by ``full_size``. The placement is recomputed
    on every window resize.
    """

    def __init__(self, images_paths: List[str], interactor, render, size: Union[float, Tuple[float, float]],
                 position: Tuple[float, float], on_click: Optional[Callable[[Any, Any], None]] = None,
                 full_size: Tuple[float, float] = (1., 1.)):
        super(ImageButton, self).__init__()
        self.SetCurrentRenderer(render)
        # A scalar size means a square button. isinstance also accepts ints
        # (e.g. size=1); a bare int would otherwise fail when indexed in set_size.
        if isinstance(size, (int, float)):
            size = (size, size)
        self.full_size = full_size
        render_window: vtk.vtkRenderWindow = interactor.GetRenderWindow()
        self.button_representation = vtk.vtkTexturedButtonRepresentation2D()
        self.button_representation.SetNumberOfStates(len(images_paths))
        self.button_representation.GetProperty().SetColor(1, 1, 1)
        # One texture per state, in the order of images_paths.
        for state, image_path in enumerate(images_paths):
            self.button_representation.SetButtonTexture(state, create_vtk_image(image_path))
        self.SetInteractor(interactor)
        self.SetRepresentation(self.button_representation)
        self.size = size
        self.position = position
        self.button_representation.SetPlaceFactor(1)
        self.set_size(render_window.GetSize())
        render_window.AddObserver(vtk.vtkCommand.WindowResizeEvent, self.resize_event)
        if on_click is not None:
            self.AddObserver(vtk.vtkCommand.StateChangedEvent, on_click)
        self.On()
        # Make the selection outline zero-width and white.
        selection_prop = self.button_representation.GetSelectingProperty()
        selection_prop.SetLineWidth(0.)
        selection_prop.SetColor(1., 1., 1.)

    def process_state_change_event(self, obj, event):
        """Debug callback: print the representation's current state."""
        print(f"end event {self.button_representation.GetState()}")

    def set_size(self, window_size: Tuple[float, float]):
        """Place the button for a window of ``window_size`` pixels.

        ``position`` gives the (left, top) corner and ``size`` the
        (width, height), both as fractions of ``window_size * full_size``.
        Pixel values are truncated to ints.
        """
        w = window_size[0] * self.full_size[0]
        h = window_size[1] * self.full_size[1]
        left, top = int(w * self.position[0]), int(h * self.position[1])
        # PlaceWidget bounds: (xmin, xmax, ymin, ymax, zmin, zmax); y grows
        # upward, so the button extends down from ``top``.
        bounds = [left,
                  left + int(w * self.size[0]),
                  top - int(h * self.size[1]),
                  top,
                  0, 0]
        self.button_representation.PlaceWidget(bounds)

    def resize_event(self, obj, event):
        """WindowResizeEvent observer: re-place the button for the new window size."""
        self.set_size(obj.GetSize())
def make_slider(iren, observer, *, to_show: bool = False):
    """Create a horizontal 2D slider widget attached to ``iren``.

    The slider runs in normalised display coordinates from (0.01, 0.1) to
    (0.23, 0.1). Its range is [0, 100] with an initial value of 30, the
    value label is hidden, and ``observer`` is registered for
    'InteractionEvent'.

    Args:
        iren: render-window interactor the widget attaches to.
        observer: VTK observer callback ``(caller, event)`` for 'InteractionEvent'.
        to_show: debug switch. When True, ``iren`` is ignored: a standalone
            window and interactor are created, the slider is attached to
            them, and ``iren.Start()`` runs that window's event loop before
            returning.

    Returns:
        ``(slider_widget, slider_repres)``.
    """
    if to_show:
        ren_left = vtk.vtkRenderer()
        ren_left.SetBackground(*rgb_to_float((250, 255, 255)))
        ren_window = vtk.vtkRenderWindow()
        ren_window.AddRenderer(ren_left)
        iren = vtk.vtkRenderWindowInteractor()
        iren.SetRenderWindow(ren_window)
        ren_window.Render()

    slider_repres = vtk.vtkSliderRepresentation2D()
    slider_repres.SetMinimumValue(0)
    slider_repres.SetMaximumValue(100.)
    slider_repres.SetValue(30.)
    slider_repres.GetSliderProperty().SetColor(*rgb_to_float(bg_target_color))
    slider_repres.ShowSliderLabelOff()
    slider_repres.GetCapProperty().SetColor(*rgb_to_float(bg_menu_color))
    slider_repres.GetSelectedProperty().SetColor(1., 0., 0)
    slider_repres.GetTubeProperty().SetColor(*rgb_to_float(bg_source_color))
    # End points, in normalised display coordinates.
    slider_repres.GetPoint1Coordinate().SetCoordinateSystemToNormalizedDisplay()
    slider_repres.GetPoint1Coordinate().SetValue(0.01, 0.1)
    slider_repres.GetPoint2Coordinate().SetCoordinateSystemToNormalizedDisplay()
    slider_repres.GetPoint2Coordinate().SetValue(0.23, 0.1)
    slider_repres.SetSliderLength(0.01)
    slider_repres.SetSliderWidth(0.01)
    slider_repres.SetEndCapLength(0.01)
    slider_repres.SetEndCapWidth(0.01)
    slider_repres.SetTubeWidth(0.01)
    slider_repres.SetLabelFormat('%f')

    slider_widget = vtk.vtkSliderWidget()
    slider_widget.SetInteractor(iren)
    slider_widget.SetRepresentation(slider_repres)
    slider_widget.KeyPressActivationOff()
    slider_widget.SetAnimationModeToAnimate()
    slider_widget.SetEnabled(True)
    slider_widget.AddObserver('InteractionEvent', observer)

    if to_show:
        iren.Initialize()
        ren_window.Render()
        iren.Start()
    return slider_widget, slider_repres
class CanvasRender(vtk.vtkRenderer):
    """A background renderer plus an overlay renderer holding a drawable canvas.

    ``self`` is put on layer 0 and paints ``bg_color``. ``self.canvas_render``
    is non-interactive, on layer 1, and shows a vtkImageCanvasSource2D sized
    to the viewport in pixels. Strokes are drawn on it with ``draw`` and
    wiped with ``clear``. Both renderers are added to ``render_window``, and
    the canvas is rebuilt on every window resize.

    NOTE(review): layer 1 is only rendered if ``render_window`` has
    SetNumberOfLayers(>= 2); that is not set here. Presumably the caller
    does it; confirm.
    """

    def __init__(self, viewport: Tuple[float, float, float, float], render_window: vtk.vtkRenderWindow,
                 bg_color: RGB_COLOR, stroke_color: Optional[RGBA_COLOR] = None):
        super(CanvasRender, self).__init__()
        self.SetViewport(*viewport)
        # Normalised (xmin, ymin, xmax, ymax); converted to pixels by set_int_viewport.
        self.viewport = viewport
        self.canvas_render = vtk.vtkRenderer()
        self.canvas_render.SetViewport(*viewport)
        if stroke_color is None:
            coral = vtk.vtkNamedColors().GetColor4ub('LightCoral')
            stroke_color = coral.GetRed(), coral.GetGreen(), coral.GetBlue(), 200
        # RGBA, 0-255 per channel.
        self.base_stroke_color = self.stroke_color = stroke_color
        self.is_draw = True
        self.last_point: Optional[Tuple[int, int]] = None
        self.canvas_render.InteractiveOff()
        self.viewport_ren = self.set_int_viewport(render_window.GetSize())
        self.canvas = vtk.vtkImageCanvasSource2D()
        self.image_data = self.init_canvas()
        render_window.AddObserver(vtk.vtkCommand.WindowResizeEvent, self.resize_event)
        self.SetLayer(0)
        self.canvas_render.SetLayer(1)
        self.SetBackground(*rgb_to_float(bg_color))
        render_window.AddRenderer(self)
        render_window.AddRenderer(self.canvas_render)
        self.set_camera()

    # Pixel geometry of the viewport. These are properties because every
    # caller uses them as attributes (e.g. ``pt[0] - self.origin_x``,
    # ``SetExtent(0, self.width, ...)``); as plain methods those expressions
    # would receive bound methods and raise TypeError.
    @property
    def origin_x(self) -> int:
        """Viewport x-min, in window pixels."""
        return self.viewport_ren[0]

    @property
    def origin_y(self) -> int:
        """Viewport y-min, in window pixels."""
        return self.viewport_ren[1]

    @property
    def width(self) -> int:
        """Viewport width, in window pixels."""
        return self.viewport_ren[2] - self.viewport_ren[0]

    @property
    def height(self) -> int:
        """Viewport height, in window pixels."""
        return self.viewport_ren[3] - self.viewport_ren[1]

    def translate_point(self, pt: Tuple[int, int]) -> Tuple[int, int]:
        """Convert a window-pixel point to canvas coordinates (relative to the viewport origin)."""
        return pt[0] - self.origin_x, pt[1] - self.origin_y

    def get_mid_points(self, pt: Tuple[int, int]) -> List[List[int]]:
        """Sample the segment from ``pt`` towards the previous point, in window pixels.

        Both ``pt`` and ``self.last_point`` are in canvas coordinates. The
        segment is sampled at ``max(int(length / 10), 2)`` evenly spaced
        points starting at ``pt``. The final sample, the previous point
        itself, is dropped. Results are shifted back to window pixels and
        truncated to ints. Returns ``[]`` when there is no previous point.
        """
        if self.last_point is None:
            return []
        pt_a = torch.tensor(pt, dtype=torch.float32)
        pt_b = torch.tensor(self.last_point, dtype=torch.float32)
        delta = pt_b - pt_a
        num_mids = max(int(delta.norm(2, 0).item() / 10), 2)
        mid_points = pt_a[None, :] + torch.linspace(0, 1, num_mids)[:, None] * delta[None, :]
        mid_points[:, 0] += self.origin_x
        mid_points[:, 1] += self.origin_y
        return mid_points[:-1].long().tolist()

    def draw(self, pt: Tuple[int, int], stroke_width: float = 5.) -> List[List[int]]:
        """Extend the current stroke to window point ``pt``.

        Draws a tube of radius ``stroke_width`` from the previous point to
        ``pt``; nothing is drawn for the first point of a stroke. Returns the
        sampled intermediate points from ``get_mid_points`` (window pixels)
        and makes ``pt`` the new previous point.
        """
        pt = self.translate_point(pt)
        if self.last_point is not None:
            self.canvas.FillTube(*self.last_point, *pt, stroke_width)
            self.canvas.Update()
        mid_points = self.get_mid_points(pt)
        self.last_point = pt
        return mid_points

    def clear(self):
        """Wipe the canvas to transparent (0, 0, 0, 0) and end the current stroke."""
        self.last_point = None
        self.canvas.SetDrawColor(0, 0, 0, 0)
        self.canvas.FillBox(0, self.width, 0, self.height)
        # Restore the active brush colour after the transparent fill.
        self.canvas.SetDrawColor(*self.stroke_color)
        self.canvas.Update()

    def resize_event_(self, obj):
        """Resize the canvas for window ``obj``'s new size, clear it and re-centre the camera."""
        self.viewport_ren = self.set_int_viewport(obj.GetSize())
        self.canvas.SetExtent(0, self.width, 0, self.height, 0, 0)
        self.canvas.Update()
        self.clear()
        self.set_camera()

    def resize_event(self, obj, event):
        """WindowResizeEvent observer; delegates to ``resize_event_``."""
        self.resize_event_(obj)

    def set_camera(self):
        """Aim canvas_render's camera at the canvas image.

        Uses parallel projection centred on the image, with a parallel scale
        of half the image height, so the image height spans the viewport.
        """
        origin = self.image_data.GetOrigin()
        spacing = self.image_data.GetSpacing()
        extent = self.image_data.GetExtent()
        camera = self.canvas_render.GetActiveCamera()
        camera.ParallelProjectionOn()
        xc = origin[0] + 0.5 * (extent[0] + extent[1]) * spacing[0]
        yc = origin[1] + 0.5 * (extent[2] + extent[3]) * spacing[1]
        yd = (extent[3] - extent[2] + 1) * spacing[1]
        d = camera.GetDistance()
        camera.SetParallelScale(0.5 * yd)
        camera.SetFocalPoint(xc, yc, 0.0)
        camera.SetPosition(xc, yc, d)

    def set_int_viewport(self, win_size) -> Tuple[int, int, int, int]:
        """Scale the normalised viewport by ``win_size`` (w, h) into integer pixels."""
        w, h = win_size
        return int(self.viewport[0] * w), int(self.viewport[1] * h), int(self.viewport[2] * w), int(self.viewport[3] * h)

    def init_canvas(self):
        """Configure the canvas source and show it in canvas_render.

        The canvas has extent 0..width x 0..height and holds 4-component
        unsigned-char scalars. Sets the draw brush, adds a vtkImageActor
        showing the canvas output, and returns that output image.
        """
        self.canvas.SetExtent(0, self.width, 0, self.height, 0, 0)
        self.canvas.PropagateUpdateExtent()
        self.canvas.UpdateExtent((0, self.width, 0, self.height, 0, 0))
        self.canvas.SetScalarTypeToUnsignedChar()
        self.canvas.SetNumberOfScalarComponents(4)
        self.set_brush(True)
        image_data = self.canvas.GetOutput()
        image_actor = vtk.vtkImageActor()
        image_actor.SetInputData(image_data)
        self.canvas_render.AddActor(image_actor)
        return image_data

    def set_brush(self, is_draw: bool):
        """Choose the brush colour.

        Uses ``base_stroke_color`` when ``is_draw`` is True. Otherwise uses
        (255, 255, 255, 200), presumably an eraser/marking brush.
        """
        self.is_draw = is_draw
        self.stroke_color = self.base_stroke_color if is_draw else (255, 255, 255, 200)
        self.canvas.SetDrawColor(*self.stroke_color)
        self.canvas.Update()

    def change_brush(self, stroke_color):
        """Replace the drawing colour (RGBA, 0-255) while keeping the current draw/erase mode."""
        self.base_stroke_color = stroke_color
        self.set_brush(self.is_draw)
def init_palettes(cmap='Spectral'):
    """Return a memoised ``get_palette(num_colors)`` function for a Matplotlib colormap.

    ``get_palette(n)`` samples ``cmap`` at ``n`` evenly spaced positions in
    [0, 1] and returns the RGBA results as an ``(n, 4)`` float tensor. For
    ``n == 1``, where even spacing would divide by zero, the colormap is
    sampled once at 0.45, so the result is still an RGBA row of shape
    (1, 4). Results are cached per ``n`` in a dict private to the returned
    closure.

    ``plt.get_cmap`` is used because ``matplotlib.cm.get_cmap`` (reached
    via ``plt.cm.get_cmap``) was removed in Matplotlib 3.9.
    """
    color_map = plt.get_cmap(cmap)
    colors = {}

    def get_palette(num_colors: int) -> T:
        if num_colors not in colors:
            if num_colors == 1:
                samples = [color_map(.45)]
            else:
                samples = [color_map(float(idx) / (num_colors - 1)) for idx in range(num_colors)]
            colors[num_colors] = torch.tensor(samples)
        return colors[num_colors]

    return get_palette
def get_view_styles(num_styles: int, is_main: bool) -> List[ViewStyle]:
    """Create ``num_styles`` ViewStyles with randomly chosen, non-repeating colours.

    A palette of ``max(num_styles, 100)`` colours is sampled from
    ``init_palettes()``, randomly permuted, and its first ``num_styles``
    entries are converted to 0-255 RGB (alpha dropped). Every style has a
    white base colour and opacity 1. When ``is_main`` is True the included
    colour is also white and only the selected colour varies; otherwise
    both use the palette colour.
    """
    base_color = (255, 255, 255)
    opacity = 1
    num_colors = max(num_styles, 100)
    palette_colors = init_palettes()(num_colors)
    # Permute the whole palette, so more than 100 styles can be returned.
    picked = palette_colors[torch.randperm(num_colors)][:num_styles].tolist()
    view_styles = []
    for rgba in picked:
        color = [int(255 * c) for c in rgba[:3]]
        included_color = base_color if is_main else color
        view_styles.append(ViewStyle(base_color, included_color, color, opacity))
    return view_styles