DragGan / viz /drag_widget.py
radames's picture
merge with main
e722a6c
raw
history blame
6.1 kB
import os
import torch
import numpy as np
import imgui
import dnnlib
from gui_utils import imgui_utils
#----------------------------------------------------------------------------
class DragWidget:
    """GUI widget that collects drag points and an edit mask for the visualizer.

    The widget keeps two parallel lists, ``points`` (handle points) and
    ``targets`` (where each handle should move to), a 2-D ``mask`` tensor,
    and a handful of numeric settings. Every frame, ``__call__`` draws the
    controls and publishes the current state on ``viz.args``.

    Coordinates are stored as ``[y, x]`` lists.
    """

    def __init__(self, viz):
        """Create the widget with no points and an all-ones 256x256 mask.

        Args:
            viz: The owning visualizer object. This class reads
                ``viz.result``, ``viz.label_w``, ``viz.button_w``,
                ``viz.font_size`` and ``viz.frame_delta``, and writes
                attributes on ``viz.args``.
        """
        self.viz = viz
        self.point = [-1, -1]       # Most recent clicked position, as [y, x].
        self.points = []            # Handle points, each [y, x].
        self.targets = []           # Target points, paired by index with `points`.
        self.is_point = True        # True: the next click is a handle; False: a target.
        self.last_click = False     # Value of `click` seen on the previous add_point call.
        self.is_drag = False
        self.iteration = 0
        self.mode = 'point'         # One of 'point', 'flexible', 'fixed'.
        self.r_mask = 50            # Brush radius used by draw_mask.
        self.show_mask = False
        # Default canvas size; kept in sync with the default mask below so that
        # draw_mask() and the 'Reset mask' button work even if init_mask() has
        # not been called yet (previously this raised AttributeError).
        self.width, self.height = 256, 256
        self.mask = torch.ones(self.height, self.width)
        self.lambda_mask = 20
        # NOTE(review): feature_idx, r1 and r2 are only forwarded to viz.args
        # here; their meaning is defined by the consumer of viz.args.
        self.feature_idx = 5
        self.r1 = 3
        self.r2 = 12
        self.path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '_screenshots'))
        self.defer_frames = 0
        self.disabled_time = 0

    def action(self, click, down, x, y):
        """Route a mouse event to point placement or mask painting.

        In 'point' mode the event is passed to ``add_point``. In any other
        mode the mask is painted at ``(x, y)`` while ``down`` is truthy.
        """
        if self.mode == 'point':
            self.add_point(click, x, y)
        elif down:
            self.draw_mask(x, y)

    def add_point(self, click, x, y):
        """Track a click and commit a point when the click ends.

        While ``click`` is truthy the candidate position is updated. On the
        first call where ``click`` is falsy after a truthy one, any running
        drag is stopped and the candidate is appended to ``points`` or
        ``targets``, alternating between the two.
        """
        if click:
            self.point = [y, x]
        elif self.last_click:
            # Click just ended: editing the point set interrupts a running drag.
            if self.is_drag:
                self.stop_drag()
            if self.is_point:
                self.points.append(self.point)
                self.is_point = False
            else:
                self.targets.append(self.point)
                self.is_point = True
        self.last_click = click

    def init_mask(self, w, h):
        """Record the canvas size and reset the mask to all ones of shape (h, w)."""
        self.width, self.height = w, h
        self.mask = torch.ones(h, w)

    def draw_mask(self, x, y):
        """Paint a filled circle of radius ``r_mask`` centred at ``(x, y)``.

        In 'flexible' mode the covered mask entries are set to 0; in 'fixed'
        mode they are set to 1. Other modes leave the mask unchanged.
        """
        xs = torch.linspace(0, self.width, self.width)
        ys = torch.linspace(0, self.height, self.height)
        # Default meshgrid ordering yields (height, width) grids, matching the mask.
        yy, xx = torch.meshgrid(ys, xs)
        circle = (xx - x)**2 + (yy - y)**2 < self.r_mask**2
        if self.mode == 'flexible':
            self.mask[circle] = 0
        elif self.mode == 'fixed':
            self.mask[circle] = 1

    def stop_drag(self):
        """Stop dragging and reset the step counter."""
        self.is_drag = False
        self.iteration = 0

    def set_points(self, points):
        """Replace the handle-point list with ``points`` (targets are untouched)."""
        self.points = points

    def reset_point(self):
        """Clear all handle and target points; the next click is a handle again."""
        self.points = []
        self.targets = []
        self.is_point = True

    def load_points(self, suffix):
        """Read ``[y, x]`` integer pairs from ``<self.path>_<suffix>.txt``.

        Each line must contain exactly two whitespace-separated integers.
        If the file cannot be opened or a line is malformed, a message is
        printed and the points parsed so far are returned.

        Returns:
            A list of ``[y, x]`` integer pairs (possibly empty).
        """
        points = []
        point_path = self.path + f'_{suffix}.txt'
        try:
            with open(point_path, "r") as f:
                for line in f:
                    y, x = line.split()
                    points.append([int(y), int(x)])
        except (OSError, ValueError):
            # OSError: missing/unreadable file. ValueError: wrong field count
            # or non-integer text. Deliberately not a bare `except`, so that
            # KeyboardInterrupt/SystemExit still propagate.
            print(f'Wrong point file path: {point_path}')
        return points

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the controls (when ``show`` is truthy) and publish state to ``viz.args``.

        ``viz.args.reset`` is True only on a frame in which the 'Reset point'
        button was pressed.
        """
        viz = self.viz
        reset = False
        if show:
            with imgui_utils.grayed_out(self.disabled_time != 0):
                has_image = 'image' in viz.result

                imgui.text('Drag')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Add point', width=viz.button_w, enabled=has_image):
                    self.mode = 'point'

                imgui.same_line()
                if imgui_utils.button('Reset point', width=viz.button_w, enabled=has_image):
                    self.reset_point()
                    reset = True

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Start', width=viz.button_w, enabled=has_image):
                    self.is_drag = True
                    if len(self.points) > len(self.targets):
                        # Drop a handle that never received a target, and go
                        # back to expecting a handle so the next click cannot
                        # add an unpaired target.
                        self.points = self.points[:len(self.targets)]
                        self.is_point = True

                imgui.same_line()
                if imgui_utils.button('Stop', width=viz.button_w, enabled=has_image):
                    self.stop_drag()

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                imgui.text(f'Steps: {self.iteration}')

                imgui.text('Mask')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Flexible area', width=viz.button_w, enabled=has_image):
                    self.mode = 'flexible'
                    self.show_mask = True

                imgui.same_line()
                if imgui_utils.button('Fixed area', width=viz.button_w, enabled=has_image):
                    self.mode = 'fixed'
                    self.show_mask = True

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Reset mask', width=viz.button_w, enabled=has_image):
                    self.mask = torch.ones(self.height, self.width)

                imgui.same_line()
                _clicked, self.show_mask = imgui.checkbox('Show mask', self.show_mask)

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.font_size * 6):
                    _changed, self.r_mask = imgui.input_int('Radius', self.r_mask)

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.font_size * 6):
                    _changed, self.lambda_mask = imgui.input_int('Lambda', self.lambda_mask)

        self.disabled_time = max(self.disabled_time - viz.frame_delta, 0)
        if self.defer_frames > 0:
            self.defer_frames -= 1

        viz.args.is_drag = self.is_drag
        if self.is_drag:
            self.iteration += 1
        viz.args.iteration = self.iteration
        # Hand out shallow copies so later edits to the widget's lists do not
        # alter the lists already published on viz.args.
        viz.args.points = list(self.points)
        viz.args.targets = list(self.targets)
        viz.args.mask = self.mask
        viz.args.lambda_mask = self.lambda_mask
        viz.args.feature_idx = self.feature_idx
        viz.args.r1 = self.r1
        viz.args.r2 = self.r2
        viz.args.reset = reset
#----------------------------------------------------------------------------