PanigrahiNirma commited on
Commit
22a8243
·
verified ·
1 Parent(s): 1d600db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -11,7 +11,7 @@ import torch
11
  DEFAULT_BLUR_RADIUS = 2
12
  DEFAULT_ERODE_ITERATIONS = 1
13
  DEFAULT_TOLERANCE = 2
14
- DEFAULT_INFERENCE_STEPS = 30 # Reduced inference steps
15
 
16
  # Initialize pipeline with optimizations
17
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -62,14 +62,14 @@ def image_to_svg_paths(image, tolerance=DEFAULT_TOLERANCE, blur_radius=DEFAULT_B
62
  svg_paths.append(path)
63
  return svg_paths
64
 
65
- async def generate_svgs(prompt, color, num_images, fill_color, stroke_color, stroke_width, tolerance, blur_radius, erode_iterations, progress=gr.Progress()):
66
  """Generate SVGs from a text prompt."""
67
  color_prompt = f"with {color} colors" if color != "Random" else ""
68
  full_prompt = f"{prompt}, {color_prompt}"
69
  svg_outputs = []
70
- for i in progress.tqdm(range(num_images)):
71
  try:
72
- image = pipe(full_prompt, num_inference_steps=DEFAULT_INFERENCE_STEPS, height=256, width=256).images[0].convert("RGB")
73
  svg_paths = image_to_svg_paths(image, tolerance=tolerance, blur_radius=blur_radius, erode_iterations=erode_iterations)
74
  svg_string = f'<svg width="256" height="256" xmlns="http://www.w3.org/2000/svg">'
75
  for path in svg_paths:
@@ -82,7 +82,7 @@ async def generate_svgs(prompt, color, num_images, fill_color, stroke_color, str
82
  svg_outputs.append(f"Error: {str(e)}")
83
 
84
  # Ensure a minimum number of outputs
85
- while len(svg_outputs) < 4:
86
  svg_outputs.append("No Image Generated")
87
  return svg_outputs
88
 
@@ -92,13 +92,13 @@ with gr.Blocks() as demo:
92
  with gr.Column():
93
  prompt_input = gr.Textbox(label="Describe the image", lines=3, placeholder="A cat playing with a ball, vibrant colors")
94
  color_input = gr.Dropdown(choices=list(COLORS.keys()), label="Desired Colors", value="Random")
95
- num_images_input = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Number of Images")
96
  fill_color_input = gr.ColorPicker(label="Fill Color", value="#000000")
97
  stroke_color_input = gr.ColorPicker(label="Stroke Color", value="#000000")
98
  stroke_width_input = gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Stroke Width")
99
  tolerance_input = gr.Slider(minimum=1, maximum=10, value=DEFAULT_TOLERANCE, step=1, label="Contour Tolerance")
100
  blur_radius_input = gr.Slider(minimum=1, maximum=5, value=DEFAULT_BLUR_RADIUS, step=1, label="Blur Radius")
101
  erode_iterations_input = gr.Slider(minimum=1, maximum=3, value=DEFAULT_ERODE_ITERATIONS, step=1, label="Erode Iterations")
 
102
  generate_button = gr.Button("Generate")
103
  gr.Markdown("## Prompt Engineering Tips:\n- Be specific and descriptive.\n- Use keywords related to style, color, and composition.\n- Experiment with different prompts to find what works best.")
104
 
@@ -106,12 +106,11 @@ with gr.Blocks() as demo:
106
  svg_output1 = gr.HTML(label="SVG 1")
107
  svg_output2 = gr.HTML(label="SVG 2")
108
  svg_output3 = gr.HTML(label="SVG 3")
109
- svg_output4 = gr.HTML(label="SVG 4")
110
 
111
  generate_button.click(
112
  fn=generate_svgs,
113
- inputs=[prompt_input, color_input, num_images_input, fill_color_input, stroke_color_input, stroke_width_input, tolerance_input, blur_radius_input, erode_iterations_input],
114
- outputs=[svg_output1, svg_output2, svg_output3, svg_output4],
115
  )
116
 
117
  demo.launch(share=True)
 
11
  DEFAULT_BLUR_RADIUS = 2
12
  DEFAULT_ERODE_ITERATIONS = 1
13
  DEFAULT_TOLERANCE = 2
14
+ DEFAULT_INFERENCE_STEPS = 20 # Reduced inference steps
15
 
16
  # Initialize pipeline with optimizations
17
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
62
  svg_paths.append(path)
63
  return svg_paths
64
 
65
+ async def generate_svgs(prompt, color, fill_color, stroke_color, stroke_width, tolerance, blur_radius, erode_iterations, inference_steps, progress=gr.Progress()):
66
  """Generate SVGs from a text prompt."""
67
  color_prompt = f"with {color} colors" if color != "Random" else ""
68
  full_prompt = f"{prompt}, {color_prompt}"
69
  svg_outputs = []
70
+ for i in progress.tqdm(range(3)): # Fixed to 3 generations
71
  try:
72
+ image = pipe(full_prompt, num_inference_steps=inference_steps, height=256, width=256).images[0].convert("RGB")
73
  svg_paths = image_to_svg_paths(image, tolerance=tolerance, blur_radius=blur_radius, erode_iterations=erode_iterations)
74
  svg_string = f'<svg width="256" height="256" xmlns="http://www.w3.org/2000/svg">'
75
  for path in svg_paths:
 
82
  svg_outputs.append(f"Error: {str(e)}")
83
 
84
  # Ensure a minimum number of outputs
85
+ while len(svg_outputs) < 3:
86
  svg_outputs.append("No Image Generated")
87
  return svg_outputs
88
 
 
92
  with gr.Column():
93
  prompt_input = gr.Textbox(label="Describe the image", lines=3, placeholder="A cat playing with a ball, vibrant colors")
94
  color_input = gr.Dropdown(choices=list(COLORS.keys()), label="Desired Colors", value="Random")
 
95
  fill_color_input = gr.ColorPicker(label="Fill Color", value="#000000")
96
  stroke_color_input = gr.ColorPicker(label="Stroke Color", value="#000000")
97
  stroke_width_input = gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Stroke Width")
98
  tolerance_input = gr.Slider(minimum=1, maximum=10, value=DEFAULT_TOLERANCE, step=1, label="Contour Tolerance")
99
  blur_radius_input = gr.Slider(minimum=1, maximum=5, value=DEFAULT_BLUR_RADIUS, step=1, label="Blur Radius")
100
  erode_iterations_input = gr.Slider(minimum=1, maximum=3, value=DEFAULT_ERODE_ITERATIONS, step=1, label="Erode Iterations")
101
+ inference_steps_input = gr.Slider(minimum=10, maximum=50, value=DEFAULT_INFERENCE_STEPS, step=5, label="Inference Steps")
102
  generate_button = gr.Button("Generate")
103
  gr.Markdown("## Prompt Engineering Tips:\n- Be specific and descriptive.\n- Use keywords related to style, color, and composition.\n- Experiment with different prompts to find what works best.")
104
 
 
106
  svg_output1 = gr.HTML(label="SVG 1")
107
  svg_output2 = gr.HTML(label="SVG 2")
108
  svg_output3 = gr.HTML(label="SVG 3")
 
109
 
110
  generate_button.click(
111
  fn=generate_svgs,
112
+ inputs=[prompt_input, color_input, fill_color_input, stroke_color_input, stroke_width_input, tolerance_input, blur_radius_input, erode_iterations_input, inference_steps_input],
113
+ outputs=[svg_output1, svg_output2, svg_output3],
114
  )
115
 
116
  demo.launch(share=True)