File size: 2,763 Bytes
3c27e5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ca1ed9
3c27e5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torchvision
import clip
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr


# Use the GPU when one is visible; all tensors below are placed on this device.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model_name = 'ViT-B/16' #@param  ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']
# Pass the device explicitly so the load location always agrees with DEVICE.
model, preprocess = clip.load(model_name, device=DEVICE)

model.to(DEVICE).eval()
# CLIP is only used as a frozen encoder: the optimiser in gradio_fn updates the
# colour parameter alone and never zeroes CLIP's gradients. Freezing the weights
# stops every backward pass from computing (and endlessly accumulating) .grad
# for all CLIP parameters, and keeps encode_* calls from building autograd graphs.
model.requires_grad_(False)

# Square side length, in pixels, fed to the visual encoder.
resolution = model.visual.input_resolution
resizer = torchvision.transforms.Resize(size=(resolution, resolution))


def create_rgb_tensor(color, device=None):
  """Build a float32 tensor of shape (1, 3, 1, 1): a single-pixel RGB image.

  Args:
    color: Three channel values, e.g. ``[1, 0, 0]`` for red. Integer values are
      stored as float32 (rather than becoming an int64 tensor) so the result
      has the same dtype as the learnable colour in ``RGBModel``.
    device: Device for the tensor; ``None`` (the default) uses ``DEVICE``.

  Returns:
    The colour reshaped to (batch=1, channels=3, height=1, width=1).
  """
  if device is None:
    device = DEVICE
  pixel = torch.tensor(color, dtype=torch.float32, device=device)
  return pixel.reshape((1, 3, 1, 1))

def encode_color(color):
  """Return CLIP's image embedding of a solid image filled with one RGB colour.

  ``color`` is a 3-element sequence such as ``[1, 0, 0]``. It is turned into a
  (1, 3, 1, 1) tensor, stretched to the encoder's input resolution, and encoded.
  """
  solid_image = resizer(create_rgb_tensor(color))
  embedding = model.encode_image(solid_image)
  return embedding

def encode_text(text):
  """Tokenise ``text`` with ``clip.tokenize`` and return CLIP's text embedding."""
  tokens = clip.tokenize(text)
  return model.encode_text(tokens.to(DEVICE))

class RGBModel(torch.nn.Module):
    """A single learnable RGB colour held as a (1, 3, 1, 1) parameter.

    The colour starts at mid-grey (0.5 in every channel). Calling the module
    clamps it into [0, 1] and returns the parameter itself.
    """

    def __init__(self, device):
        super().__init__()
        self.color = torch.nn.Parameter(torch.ones((1, 3, 1, 1), device=device) / 2)

    def forward(self):
        """Clamp the colour to the closed interval [0, 1] and return the parameter."""
        # In-place update under no_grad is the supported replacement for
        # reassigning `.data`: autograd does not record the clamp, and the same
        # Parameter object (and its storage) stays registered with the optimiser.
        with torch.no_grad():
            self.color.clamp_(0, 1)
        return self.color

# Input widgets for the demo UI.
text_input = gr.inputs.Textbox(
    label="Text Prompt",
    lines=1,
    default='A solid red square',
)
steps_input = gr.inputs.Slider(
    label="Training Steps",
    minimum=1,
    maximum=100,
    step=1,
    default=11,
)
lr_input = gr.inputs.Number(label="Adam Optimizer Learning Rate", default=0.06)
decay_input = gr.inputs.Number(label="Adam Optimizer Weight Decay", default=0.01)



def gradio_fn(text_prompt, adam_learning_rate, adam_weight_decay, n_iterations=50):
  """Optimise one RGB colour so its CLIP image embedding matches a text prompt.

  Starting from mid-grey, AdamW adjusts a single colour (rendered as a solid
  image at CLIP's input resolution) to maximise the cosine similarity between
  its CLIP image embedding and the CLIP text embedding of ``text_prompt``.

  Args:
    text_prompt: Text whose CLIP embedding is the optimisation target.
    adam_learning_rate: ``lr`` passed to ``torch.optim.AdamW``.
    adam_weight_decay: ``weight_decay`` passed to ``torch.optim.AdamW``.
    n_iterations: Number of optimisation steps. Cast to ``int`` because UI
      sliders may deliver it as a float, which ``range`` rejects.

  Returns:
    A 400x100 PIL image made of one vertical stripe per recorded colour: the
    initial colour on the left, then the colour after each step.
  """
  n_iterations = int(n_iterations)

  rgb_model = RGBModel(device=DEVICE)
  opt = torch.optim.AdamW(rgb_model.parameters(), lr=adam_learning_rate, weight_decay=adam_weight_decay)

  with torch.no_grad():
    target_embedding = encode_text(text_prompt)

  def training_step():
    # One AdamW update of the colour towards the prompt embedding.
    color = rgb_model()
    image_embedding = model.encode_image(resizer(color))
    loss = -torch.cosine_similarity(target_embedding, image_embedding, dim=-1)
    # Differentiate w.r.t. the colour only. A plain loss.backward() would also
    # compute .grad for every CLIP weight that requires grad, and those
    # gradients would keep accumulating because the optimiser never clears them.
    (grad,) = torch.autograd.grad(loss, [rgb_model.color])
    rgb_model.color.grad = grad
    opt.step()

  def snapshot():
    # Current (clamped) colour as a length-3 numpy array.
    return rgb_model().detach().cpu().numpy().reshape(3)

  history = [snapshot()]
  for _ in range(n_iterations):
    training_step()
    history.append(snapshot())

  # Shape (1, n_iterations + 1, 3): a one-pixel-tall RGB strip, one column per step.
  strip = np.stack(history)[np.newaxis]
  # Resample filter 0 is nearest-neighbour, which keeps each colour a crisp stripe.
  img_train = Image.fromarray((strip * 255).astype(np.uint8)).resize((400, 100), 0)

  return img_train

# Build the web UI: the four controls are passed positionally to
# gradio_fn(text_prompt, adam_learning_rate, adam_weight_decay, n_iterations).
ui_inputs = [text_input, lr_input, decay_input, steps_input]
iface = gr.Interface(
    fn=gradio_fn,
    inputs=ui_inputs,
    outputs="image",
)
iface.launch()