adirik commited on
Commit
a81e960
β€’
1 Parent(s): bb7ed14
Files changed (1) hide show
  1. app.py +29 -12
app.py CHANGED
@@ -28,8 +28,11 @@ def process_image(image, prompt):
28
  outputs = model(**inputs)
29
  preds = outputs.logits
30
 
31
- filename = f"mask.png"
32
- plt.imsave(filename, torch.sigmoid(preds))
 
 
 
33
  return Image.open("mask.png").convert("RGB")
34
 
35
 
@@ -39,8 +42,8 @@ def read_content(file_path):
39
  return content
40
 
41
 
42
- def predict(dict, text_query, reference, scale, seed, step):
43
- width, height = dict["image"].size
44
  if width < height:
45
  factor = width / 512.0
46
  width = 512
@@ -51,10 +54,10 @@ def predict(dict, text_query, reference, scale, seed, step):
51
  height = 512
52
  width = int((width / factor) / 8.0) * 8
53
 
54
- init_image = dict["image"].convert("RGB").resize((width, height))
55
- mask = dict["mask"].convert("RGB").resize((width, height))
56
- print(np.array(mask))
57
- print(text_query)
58
  generator = torch.Generator('cuda').manual_seed(seed) if seed != 0 else None
59
  output = pipe(
60
  image=init_image,
@@ -122,7 +125,7 @@ with image_blocks as demo:
122
  with gr.Box():
123
  with gr.Row():
124
  with gr.Column():
125
- image = gr.Image(source="upload", tool="sketch", elem_id="image_upload", type="pil", label="Source Image")
126
  text = gr.Textbox(lines=1, placeholder="Clothing item you want to replace...")
127
  reference = gr.Image(source="upload", elem_id="image_upload", type="pil", label="Reference Image")
128
 
@@ -146,11 +149,25 @@ with image_blocks as demo:
146
 
147
  with gr.Row():
148
  with gr.Column():
149
- gr.Examples(image_list, inputs=[image],label="Examples - Source Image",examples_per_page=12)
 
 
 
 
 
150
  with gr.Column():
151
- gr.Examples(ref_list, inputs=[reference],label="Examples - Reference Image",examples_per_page=12)
 
 
 
 
 
152
 
153
- btn.click(fn=predict, inputs=[image, text, reference, guidance, seed, steps], outputs=[image_out, community_icon, loading_icon, share_button])
 
 
 
 
154
  share_button.click(None, [], [], _js=share_js)
155
 
156
  gr.HTML(
 
28
  outputs = model(**inputs)
29
  preds = outputs.logits
30
 
31
+ filename = "mask.png"
32
+ preds = torch.sigmoid(preds)
33
+ preds[preds >= 0.5] = 1
34
+ preds[preds < 0.5] = 0
35
+ plt.imsave(filename, preds)
36
  return Image.open("mask.png").convert("RGB")
37
 
38
 
 
42
  return content
43
 
44
 
45
+ def predict(input_image, text_query, reference, scale, seed, step):
46
+ width, height = input_image.size
47
  if width < height:
48
  factor = width / 512.0
49
  width = 512
 
54
  height = 512
55
  width = int((width / factor) / 8.0) * 8
56
 
57
+ init_image = input_image.convert("RGB").resize((width, height))
58
+ mask = process_image(input_image, text_query).resize((width, height))
59
+ #mask = dict["mask"].convert("RGB").resize((width, height))
60
+
61
  generator = torch.Generator('cuda').manual_seed(seed) if seed != 0 else None
62
  output = pipe(
63
  image=init_image,
 
125
  with gr.Box():
126
  with gr.Row():
127
  with gr.Column():
128
+ image = gr.Image(source="upload", elem_id="image_upload", type="pil", label="Source Image")
129
  text = gr.Textbox(lines=1, placeholder="Clothing item you want to replace...")
130
  reference = gr.Image(source="upload", elem_id="image_upload", type="pil", label="Reference Image")
131
 
 
149
 
150
  with gr.Row():
151
  with gr.Column():
152
+ gr.Examples(
153
+ image_list,
154
+ inputs=[image],
155
+ label="Examples - Source Image",
156
+ examples_per_page=12
157
+ )
158
  with gr.Column():
159
+ gr.Examples(
160
+ ref_list,
161
+ inputs=[reference],
162
+ label="Examples - Reference Image",
163
+ examples_per_page=12
164
+ )
165
 
166
+ btn.click(
167
+ fn=predict,
168
+ inputs=[image, text, reference, guidance, seed, steps],
169
+ outputs=[image_out, community_icon, loading_icon, share_button]
170
+ )
171
  share_button.click(None, [], [], _js=share_js)
172
 
173
  gr.HTML(