rishh76 committed on
Commit 493e780 · verified · 1 Parent(s): c015ca1

Update app.py

Files changed (1)
  1. app.py +110 -45
app.py CHANGED
@@ -1,41 +1,48 @@
  import gradio as gr
  import numpy as np
  import random
- import torch
- from diffusers import StableDiffusionPipeline
-
- # Load the Stable Diffusion model for text-based garment generation
- model_id = "runwayml/stable-diffusion-v1-5"
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
- pipe = pipe.to("cuda")  # Use GPU for faster inference
-
- MAX_SEED = 999999
-
- def generate_garment(person_img, cloth_description, seed, randomize_seed):
-     if person_img is None or cloth_description is None or cloth_description.strip() == "":
-         return None, None, "Invalid input"

      if randomize_seed:
          seed = random.randint(0, MAX_SEED)

-     # Generate garment image from the text description
-     torch.manual_seed(seed)
-     garment_img = pipe(cloth_description).images[0]

-     # Combine the generated garment with the person's image
-     result_img = combine_images(person_img, garment_img)

      return result_img, seed, "Success"

- def combine_images(person_img, garment_img):
-     person_img = np.array(person_img)
-     garment_img = np.array(garment_img.resize((person_img.shape[1], person_img.shape[0])))

-     # Simple overlay of garment on the person image
-     # Further improvement may require segmentation/masking
-     result_img = np.where(garment_img[:, :, 3:] > 0, garment_img[:, :, :3], person_img)

-     return result_img

  css = """
  #col-left {
@@ -54,32 +61,90 @@ css = """
      margin: 0 auto;
      max-width: 1100px;
  }
  """

  with gr.Blocks(css=css) as Tryon:
-     gr.HTML("<h1>Virtual Try-On with Text-based Garment Generation</h1>")
-
      with gr.Row():
          with gr.Column(elem_id="col-left"):
-             gr.HTML("<h3>Step 1: Upload a person image ⬇️</h3>")
-             person_img = gr.Image(label="Person Image", source='upload', type="numpy")
-
          with gr.Column(elem_id="col-mid"):
-             gr.HTML("<h3>Step 2: Describe the garment ⬇️</h3>")
-             cloth_description = gr.Textbox(label="Garment Description", placeholder="e.g., red dress with floral pattern")
-
          with gr.Column(elem_id="col-right"):
-             gr.HTML("<h3>Step 3: Generate Try-On Image ⬇️</h3>")
-             result_img = gr.Image(label="Result", show_share_button=False)
-             seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
-             randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
-             seed_used = gr.Number(label="Seed Used", interactive=False)
-             result_info = gr.Text(label="Status", interactive=False)
-
-     generate_button = gr.Button(value="Run")
-
-     generate_button.click(fn=generate_garment,
-                           inputs=[person_img, cloth_description, seed, randomize_seed],
-                           outputs=[result_img, seed_used, result_info])

  Tryon.launch()
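Side note on the removed helper: combine_images masks with garment_img[:, :, 3:], i.e. it expects a fourth (alpha) channel, but StableDiffusionPipeline returns plain RGB images, so the np.where overlay would fail with a broadcast error at runtime. Below is a minimal sketch of an alpha-aware overlay, assuming the garment arrives as an RGBA PIL image; that input format is an assumption for illustration, not something the original app provides.

import numpy as np
from PIL import Image

def overlay_rgba(person_img: np.ndarray, garment_rgba: Image.Image) -> np.ndarray:
    """Overlay an RGBA garment onto an RGB person image (illustrative sketch only)."""
    # Resize the garment to match the person image (PIL resize takes (width, height)).
    garment = np.array(garment_rgba.resize((person_img.shape[1], person_img.shape[0])))
    if garment.ndim != 3 or garment.shape[2] < 4:
        # No alpha channel to mask with; return the person image unchanged.
        return person_img
    alpha = garment[:, :, 3:4].astype(np.float32) / 255.0
    # Alpha-blend instead of a hard np.where so soft garment edges stay smooth.
    blended = alpha * garment[:, :, :3] + (1.0 - alpha) * person_img[:, :, :3]
    return blended.astype(np.uint8)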
 
+ import os
+ import cv2
  import gradio as gr
  import numpy as np
  import random
+ import time

+ def tryon(person_img, garment_prompt, seed, randomize_seed):
+     post_start_time = time.time()
+
+     if person_img is None or garment_prompt.strip() == "":
+         return None, None, "Empty image or prompt"
+
      if randomize_seed:
          seed = random.randint(0, MAX_SEED)
+
+     # Create a copy of the person image to overlay text
+     result_img = person_img.copy()
+
+     # Convert the image to OpenCV format (if needed)
+     if len(result_img.shape) == 2:  # Convert grayscale to RGB
+         result_img = cv2.cvtColor(result_img, cv2.COLOR_GRAY2RGB)
+
+     # Set text position and properties
+     text_position = (10, 30)
+     font = cv2.FONT_HERSHEY_SIMPLEX
+     font_scale = 1
+     font_color = (0, 255, 0)  # Green color for the text
+     thickness = 2

+     # Overlay the garment description text on the image
+     cv2.putText(result_img, f'Garment: {garment_prompt}', text_position, font, font_scale, font_color, thickness, cv2.LINE_AA)

+     post_end_time = time.time()
+     print(f"post time used: {post_end_time - post_start_time}")

+     # Return the resulting image, used seed, and success message
      return result_img, seed, "Success"

+ MAX_SEED = 999999

+ example_path = os.path.join(os.path.dirname(__file__), 'assets')

+ human_list = os.listdir(os.path.join(example_path, "human"))
+ human_list_path = [os.path.join(example_path, "human", human) for human in human_list]

  css = """
  #col-left {
 
      margin: 0 auto;
      max-width: 1100px;
  }
+ #button {
+     color: blue;
+ }
  """

+ def load_description(fp):
+     with open(fp, 'r', encoding='utf-8') as f:
+         content = f.read()
+     return content
+
+
  with gr.Blocks(css=css) as Tryon:
+     gr.HTML(load_description("assets/title.md"))
      with gr.Row():
          with gr.Column(elem_id="col-left"):
+             gr.HTML("""
+             <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
+                 <div>
+                     Step 1. Upload a person image ⬇️
+                 </div>
+             </div>
+             """)
          with gr.Column(elem_id="col-mid"):
+             gr.HTML("""
+             <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
+                 <div>
+                     Step 2. Enter a text prompt for the garment ⬇️
+                 </div>
+             </div>
+             """)
          with gr.Column(elem_id="col-right"):
+             gr.HTML("""
+             <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
+                 <div>
+                     Step 3. Press “Run” to get try-on results
+                 </div>
+             </div>
+             """)
+     with gr.Row():
+         with gr.Column(elem_id="col-left"):
+             imgs = gr.Image(label="Person image", sources='upload', type="numpy")
+             example = gr.Examples(
+                 inputs=imgs,
+                 examples_per_page=12,
+                 examples=human_list_path
+             )
+         with gr.Column(elem_id="col-mid"):
+             garment_prompt = gr.Textbox(label="Garment text prompt", placeholder="Describe the garment...")
+         with gr.Column(elem_id="col-right"):
+             image_out = gr.Image(label="Result", show_share_button=False)
+             with gr.Row():
+                 seed = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=0,
+                 )
+                 randomize_seed = gr.Checkbox(label="Random seed", value=True)
+             with gr.Row():
+                 seed_used = gr.Number(label="Seed used")
+                 result_info = gr.Text(label="Response")
+             test_button = gr.Button(value="Run", elem_id="button")
+
+     test_button.click(fn=tryon, inputs=[imgs, garment_prompt, seed, randomize_seed], outputs=[image_out, seed_used, result_info], concurrency_limit=40)
+
+     with gr.Column(elem_id="col-showcase"):
+         gr.HTML("""
+         <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
+             <div> </div>
+             <br>
+             <div>
+                 Virtual try-on examples in pairs of person and garment images
+             </div>
+         </div>
+         """)
+         show_case = gr.Examples(
+             examples=[
+                 ["assets/examples/model2.png", "assets/examples/garment2.png", "assets/examples/result2.png"],
+                 ["assets/examples/model3.png", "assets/examples/garment3.png", "assets/examples/result3.png"],
+                 ["assets/examples/model1.png", "assets/examples/garment1.png", "assets/examples/result1.png"],
+             ],
+             inputs=[imgs, garment_prompt, image_out],
+             label=None
+         )

  Tryon.launch()
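For completeness: the new tryon function does not synthesize a garment at all; it stamps the prompt text onto a copy of the uploaded image with OpenCV and returns it. The core step can be exercised outside Gradio with a few lines. This is a hedged sketch that mirrors the function body rather than importing app.py (the module lists assets/human and launches the UI at import time), assuming opencv-python and numpy are installed.

import cv2
import numpy as np

# Mirror of the tryon() overlay step from app.py (local sanity check only).
person_img = np.zeros((256, 256, 3), dtype=np.uint8)  # stand-in for an uploaded photo
garment_prompt = "red dress with floral pattern"

result_img = person_img.copy()
if len(result_img.shape) == 2:  # a grayscale input needs 3 channels for colored text
    result_img = cv2.cvtColor(result_img, cv2.COLOR_GRAY2RGB)

# Draw the prompt in green at the top-left corner, as the app does.
cv2.putText(result_img, f"Garment: {garment_prompt}", (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)

cv2.imwrite("tryon_preview.png", result_img)  # inspect the annotated frame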