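"""Hugging Face Space: optimize a single RGB color with CLIP + Gradio.

Given a text prompt, the app nudges a solid color with AdamW so that its CLIP
image embedding matches the prompt's CLIP text embedding, then displays the
color at each optimization step as a horizontal strip.
"""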
import numpy as np
from PIL import Image
import torch
import torchvision
import clip
import gradio as gr
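# Pick a device and load the selected CLIP model; `resizer` upsamples a 1x1
# color tensor to the resolution the CLIP image encoder expects.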
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name = 'ViT-B/16' #@param ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']
model, preprocess = clip.load(model_name)
model.to(DEVICE).eval()
resolution = model.visual.input_resolution
resizer = torchvision.transforms.Resize(size=(resolution, resolution))
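# Helpers: embed a solid RGB color (rendered as an image) or a text prompt with CLIP.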
def create_rgb_tensor(color):
    """color is e.g. [1, 0, 0]; returns a (1, 3, 1, 1) float tensor on DEVICE."""
    return torch.tensor(color, device=DEVICE, dtype=torch.float32).reshape((1, 3, 1, 1))

def encode_color(color):
    """Encode a solid color with the CLIP image encoder."""
    rgb = create_rgb_tensor(color)
    return model.encode_image(resizer(rgb))

def encode_text(text):
    """Encode a text prompt with the CLIP text encoder."""
    tokenized_text = clip.tokenize(text).to(DEVICE)
    return model.encode_text(tokenized_text)
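# A single learnable RGB color, stored as a (1, 3, 1, 1) parameter.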
class RGBModel(torch.nn.Module):
    def __init__(self, device):
        # Call nn.Module.__init__() to set up the usual torch.nn.Module machinery
        super().__init__()
        # Start from mid-grey (0.5, 0.5, 0.5)
        self.color = torch.nn.Parameter(torch.ones((1, 3, 1, 1), device=device) / 2)

    def forward(self):
        # Clamp the color to the closed interval [0, 1] before returning it
        self.color.data = self.color.data.clamp(0, 1)
        return self.color
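# Gradio UI controls (this uses the legacy `gr.inputs` API, removed in newer Gradio releases).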
text_input = gr.inputs.Textbox(lines=1, label="Text Prompt", default='A solid red square')
steps_input = gr.inputs.Slider(minimum=1, maximum=100, step=1, default=11, label="Training Steps")
lr_input = gr.inputs.Number(default=0.06, label="Adam Optimizer Learning Rate")
decay_input = gr.inputs.Number(default=0.01, label="Adam Optimizer Weight Decay")
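# Optimize the RGB color so its CLIP image embedding matches the text prompt,
# then return a strip image showing the color at each training step.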
def gradio_fn(text_prompt, adam_learning_rate, adam_weight_decay, n_iterations=50):
    rgb_model = RGBModel(device=DEVICE)
    opt = torch.optim.AdamW(rgb_model.parameters(), lr=adam_learning_rate, weight_decay=adam_weight_decay)

    # Encode the target text once; no gradients are needed on the text side.
    with torch.no_grad():
        tokenized_text = clip.tokenize(text_prompt).to(DEVICE)
        target_embedding = model.encode_text(tokenized_text).detach().clone()

    def training_step():
        opt.zero_grad()
        color = rgb_model()
        color_img = resizer(color)
        image_embedding = model.encode_image(color_img)
        # Maximize cosine similarity between the color's image embedding and the text embedding.
        loss = -1 * torch.cosine_similarity(target_embedding, image_embedding, dim=-1)
        loss.backward()
        opt.step()
    # Record the color before training and after every optimization step.
    steps = [rgb_model().cpu().detach().numpy()]
    for iteration in range(n_iterations):
        training_step()
        steps.append(rgb_model().cpu().detach().numpy())

    # Each entry has shape (1, 3, 1, 1); collapse to one RGB value per step,
    # render the sequence as a 1-pixel-tall strip, and upscale it for display.
    colors = np.stack(steps)[:, 0, :, 0, 0]              # (n_iterations + 1, 3)
    strip = (colors[None, :, :] * 255).astype(np.uint8)  # (1, n_iterations + 1, 3)
    img_train = Image.fromarray(strip).resize((400, 100), Image.NEAREST)
    return img_train
iface = gr.Interface(
    fn=gradio_fn,
    inputs=[text_input, lr_input, decay_input, steps_input],
    outputs="image",
)
iface.launch()
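# Example (hypothetical, run outside the Gradio UI): optimize a color for a
# prompt directly and save the resulting training strip.
#
#   img = gradio_fn("A solid red square", adam_learning_rate=0.06,
#                   adam_weight_decay=0.01, n_iterations=50)
#   img.save("red_training_strip.png")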