|
import os |
|
import torch |
|
import numpy as np |
|
import imgui |
|
import dnnlib |
|
from gui_utils import imgui_utils |
|
|
|
|
|
|
|
class DragWidget:
    """ImGui panel for placing drag handle/target points and painting an edit mask.

    The widget keeps its state on the instance and, on every ``__call__``,
    copies that state onto ``viz.args`` for the consumer of those args.

    Modes (``self.mode``):
        'point'    -- mouse releases alternately add a handle point and a target.
        'flexible' -- holding the mouse button writes 0 into ``self.mask``.
        'fixed'    -- holding the mouse button writes 1 into ``self.mask``.
    """

    def __init__(self, viz):
        """Create the widget bound to the visualizer object ``viz``."""
        self.viz = viz
        self.point = [-1, -1]       # last clicked position, stored as [y, x]
        self.points = []            # handle points, each [y, x]
        self.targets = []           # target points, each [y, x]
        self.is_point = True        # True: next release adds a handle; False: adds a target
        self.last_click = False     # `click` value seen on the previous add_point() call
        self.is_drag = False
        self.iteration = 0          # frames elapsed since the current drag started
        self.mode = 'point'
        self.r_mask = 50            # brush radius used by draw_mask()
        self.show_mask = False
        # Default mask size until init_mask() supplies the real one. width/height
        # are kept in sync with the mask so draw_mask() and 'Reset mask' work
        # even if init_mask() has not been called yet.
        self.width, self.height = 256, 256
        self.mask = torch.ones(self.height, self.width)
        self.lambda_mask = 20
        # Passed through to viz.args unchanged; their meaning is defined by
        # the consumer of viz.args (not visible here).
        self.feature_idx = 5
        self.r1 = 3
        self.r2 = 12
        # Base path for point files; load_points() appends '_<suffix>.txt'.
        self.path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '_screenshots'))
        self.defer_frames = 0
        self.disabled_time = 0      # while non-zero the controls are grayed out

    def action(self, click, down, x, y):
        """Route a mouse event at position (x, y) to the active mode.

        In 'point' mode the click state goes to add_point(); in the mask
        modes the mask is painted at (x, y) while ``down`` is true.
        """
        if self.mode == 'point':
            self.add_point(click, x, y)
        elif down:
            self.draw_mask(x, y)

    def add_point(self, click, x, y):
        """Record a click and commit it as a handle or target on release.

        While ``click`` is true the position is remembered as [y, x]. On the
        first call after the click ends, any running drag is stopped and the
        remembered position is appended to ``points`` or ``targets``,
        alternating between the two.
        """
        if click:
            self.point = [y, x]
        elif self.last_click:
            if self.is_drag:
                self.stop_drag()
            if self.is_point:
                self.points.append(self.point)
                self.is_point = False
            else:
                self.targets.append(self.point)
                self.is_point = True
        self.last_click = click

    def init_mask(self, w, h):
        """Reset the mask to all ones with ``h`` rows and ``w`` columns."""
        self.width, self.height = w, h
        self.mask = torch.ones(h, w)

    def draw_mask(self, x, y):
        """Paint a disc of radius ``self.r_mask`` centred on (x, y) into the mask.

        Grid coordinates are the integer pixel indices of ``self.mask``
        (rows 0..height-1, columns 0..width-1). Writes 0 in 'flexible' mode
        and 1 in 'fixed' mode; any other mode leaves the mask untouched.
        """
        # Broadcasting a (1, W) row of column indices against an (H, 1) column
        # of row indices yields an (H, W) boolean disc without a meshgrid.
        cols = torch.arange(self.width).view(1, -1)
        rows = torch.arange(self.height).view(-1, 1)
        circle = (cols - x) ** 2 + (rows - y) ** 2 < self.r_mask ** 2
        if self.mode == 'flexible':
            self.mask[circle] = 0
        elif self.mode == 'fixed':
            self.mask[circle] = 1

    def stop_drag(self):
        """Stop dragging and reset the step counter."""
        self.is_drag = False
        self.iteration = 0

    def set_points(self, points):
        """Replace the handle-point list with ``points``."""
        self.points = points

    def reset_point(self):
        """Clear all handle and target points; the next release adds a handle."""
        self.points = []
        self.targets = []
        self.is_point = True

    def load_points(self, suffix):
        """Read points from ``<self.path>_<suffix>.txt``.

        Each non-blank line must hold two integers ``y x``. On a missing or
        unreadable file, or a malformed line, a message is printed and the
        points parsed so far are returned.

        Returns:
            list of [y, x] integer pairs.
        """
        points = []
        point_path = self.path + f'_{suffix}.txt'
        try:
            with open(point_path, "r") as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        continue  # tolerate blank / trailing lines
                    y, x = fields
                    points.append([int(y), int(x)])
        except OSError:
            print(f'Wrong point file path: {point_path}')
        except ValueError:
            print(f'Malformed point file: {point_path}')
        return points

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the panel (when ``show``) and publish the widget state to ``viz.args``.

        Reads ``viz.result``, ``viz.label_w``, ``viz.button_w``,
        ``viz.font_size`` and ``viz.frame_delta``; writes is_drag, iteration,
        points, targets, mask, lambda_mask, feature_idx, r1, r2 and reset
        onto ``viz.args``.
        """
        viz = self.viz
        reset = False
        if show:
            # The buttons below only mutate self, so this is stable for the frame.
            has_image = 'image' in viz.result
            with imgui_utils.grayed_out(self.disabled_time != 0):
                imgui.text('Drag')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Add point', width=viz.button_w, enabled=has_image):
                    self.mode = 'point'

                imgui.same_line()
                if imgui_utils.button('Reset point', width=viz.button_w, enabled=has_image):
                    self.reset_point()
                    reset = True

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Start', width=viz.button_w, enabled=has_image):
                    self.is_drag = True
                    if len(self.points) > len(self.targets):
                        # Drop a handle that has no target yet, and make the next
                        # release start a new handle so points/targets stay paired.
                        self.points = self.points[:len(self.targets)]
                        self.is_point = True

                imgui.same_line()
                if imgui_utils.button('Stop', width=viz.button_w, enabled=has_image):
                    self.stop_drag()

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                imgui.text(f'Steps: {self.iteration}')

                imgui.text('Mask')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Flexible area', width=viz.button_w, enabled=has_image):
                    self.mode = 'flexible'
                    self.show_mask = True

                imgui.same_line()
                if imgui_utils.button('Fixed area', width=viz.button_w, enabled=has_image):
                    self.mode = 'fixed'
                    self.show_mask = True

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Reset mask', width=viz.button_w, enabled=has_image):
                    self.mask = torch.ones(self.height, self.width)
                imgui.same_line()
                _clicked, self.show_mask = imgui.checkbox('Show mask', self.show_mask)

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.font_size * 6):
                    _changed, self.r_mask = imgui.input_int('Radius', self.r_mask)

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.font_size * 6):
                    _changed, self.lambda_mask = imgui.input_int('Lambda', self.lambda_mask)

        self.disabled_time = max(self.disabled_time - viz.frame_delta, 0)
        if self.defer_frames > 0:
            self.defer_frames -= 1

        viz.args.is_drag = self.is_drag
        if self.is_drag:
            self.iteration += 1
        viz.args.iteration = self.iteration
        # Shallow copies so later list mutations here don't alias viz.args.
        viz.args.points = list(self.points)
        viz.args.targets = list(self.targets)
        viz.args.mask = self.mask
        viz.args.lambda_mask = self.lambda_mask
        viz.args.feature_idx = self.feature_idx
        viz.args.r1 = self.r1
        viz.args.r2 = self.r2
        viz.args.reset = reset
|
|
|
|
|
|
|
|