miccull commited on
Commit
3c27e5d
1 Parent(s): 43d1aca

initial commit

Browse files
Files changed (2) hide show
  1. app.py +85 -0
  2. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ from PIL import Image
4
+ import torch
5
+ import torchvision
6
+ import clip
7
+ import matplotlib.pyplot as plt
8
+ import seaborn as sns
9
+ import gradio as gr
10
+
11
+
12
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
13
+
14
+ model_name = 'ViT-B/16' #@param ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']
15
+ model, preprocess = clip.load(model_name)
16
+
17
+ model.to(DEVICE).eval()
18
+ resolution = model.visual.input_resolution
19
+ resizer = torchvision.transforms.Resize(size=(resolution, resolution))
20
+
21
+
22
+ def create_rgb_tensor(color):
23
+ """color is e.g. [1,0,0]"""
24
+ return torch.tensor(color, device=DEVICE).reshape((1, 3, 1, 1))
25
+
26
+ def encode_color(color):
27
+ """color is e.g. [1,0,0]"""
28
+ rgb = create_rgb_tensor(color)
29
+ return model.encode_image( resizer(rgb) )
30
+
31
+ def encode_text(text):
32
+ tokenized_text = clip.tokenize(text).to(DEVICE)
33
+ return model.encode_text(tokenized_text)
34
+
35
+ class RGBModel(torch.nn.Module):
36
+ def __init__(self, device):
37
+ # Call nn.Module.__init__() to instantiate typical torch.nn.Module stuff
38
+ super(RGBModel, self).__init__()
39
+ self.color = torch.nn.Parameter(torch.ones((1, 3, 1, 1), device=device) / 2)
40
+
41
+ def forward(self):
42
+ # Clamp numbers to the closed interval [0,1]
43
+ self.color.data = self.color.data.clamp(0,1)
44
+
45
+ return self.color
46
+
47
+ text_input = gr.inputs.Textbox(lines=1, label="Text Prompt", default='A solid red square')
48
+ steps_input = gr.inputs.Slider(minimum=1, maximum=100, step=1, default=11, label="Training Steps")
49
+ lr_input = gr.inputs.Number(default=0.06, label="Adam Optimizer Learning Rate")
50
+ decay_input = gr.inputs.Number(default=0.01, label="Adam Optimizer Weight Decay")
51
+
52
+
53
+
54
+ def gradio_fn(text_prompt, adam_learning_rate, adam_weight_decay, n_iterations=50):
55
+
56
+ rgb_model = RGBModel(device=DEVICE)
57
+ opt = torch.optim.AdamW([rgb_model()], lr=adam_learning_rate, weight_decay=adam_weight_decay)
58
+
59
+ with torch.no_grad():
60
+ tokenized_text = clip.tokenize(text_prompt).cuda()
61
+ target_embedding = model.encode_text(tokenized_text).detach().clone()
62
+
63
+ def training_step():
64
+ opt.zero_grad()
65
+ color = rgb_model()
66
+ color_img = resizer(color)
67
+ image_embedding = model.encode_image(color_img)
68
+ loss = -1 * torch.cosine_similarity(target_embedding, image_embedding, dim=-1)
69
+ loss.backward()
70
+ opt.step()
71
+
72
+ steps = []
73
+ steps.append(rgb_model().cpu().detach().numpy())
74
+ for iteration in range(n_iterations):
75
+ training_step()
76
+ steps.append(rgb_model().cpu().detach().numpy())
77
+
78
+ steps = np.stack([steps])
79
+
80
+ img_train = Image.fromarray((steps[:,:,0,:,0,0] * 255).astype(np.uint8)).resize((400, 100), 0)
81
+
82
+ return img_train
83
+
84
+ iface = gr.Interface( fn=gradio_fn, inputs=[text_input, lr_input, decay_input, steps_input], outputs="image")
85
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ comet_ml
2
+ ftfy
3
+ regex
4
+ git+https://github.com/openai/CLIP.git
5
+ pandas
6
+ Pillow
7
+ tqdm
8
+ torch
9
+ torchvision
10
+ matplotlib
11
+ seaborn