Singularity666 commited on
Commit
66f79e7
·
verified ·
1 Parent(s): c10dc95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -38
app.py CHANGED
@@ -1,46 +1,65 @@
1
- # app.py
2
-
3
- import os
4
  import gradio as gr
5
- from huggingface_hub import HfFolder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- def launch_gradio_app(fine_tune_model, load_model, generate_images, push_to_huggingface, repo_name):
8
  with gr.Blocks() as demo:
9
- gr.Markdown("# Dreambooth App")
10
-
11
- with gr.Tab("Fine-tune Model"):
12
- with gr.Row():
13
- instance_images = gr.File(label="Instance Images", file_count="multiple")
14
- class_images = gr.File(label="Class Images", file_count="multiple")
15
- with gr.Row():
16
- instance_prompt = gr.Textbox(label="Instance Prompt")
17
- class_prompt = gr.Textbox(label="Class Prompt")
18
  with gr.Row():
19
- num_train_steps = gr.Number(label="Number of Training Steps", value=800)
20
- fine_tune_button = gr.Button("Fine-tune Model")
 
 
 
 
21
 
22
  with gr.Tab("Generate Images"):
23
  with gr.Row():
24
- prompt = gr.Textbox(label="Prompt")
25
- negative_prompt = gr.Textbox(label="Negative Prompt")
26
- with gr.Row():
27
- num_samples = gr.Number(label="Number of Samples", value=1)
28
- guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
29
- with gr.Row():
30
- height = gr.Number(label="Height", value=512)
31
- width = gr.Number(label="Width", value=512)
32
- num_inference_steps = gr.Slider(label="Number of Inference Steps", value=50, minimum=1, maximum=100)
33
- generate_button = gr.Button("Generate Images")
34
- output_images = gr.Gallery()
35
 
36
- with gr.Tab("Push to Hugging Face"):
37
- push_button = gr.Button("Push Model to Hugging Face")
38
- huggingface_link = gr.Textbox(label="Hugging Face Model Link")
39
-
40
- fine_tune_button.click(fine_tune_model, inputs=[instance_images, class_images, instance_prompt, class_prompt, num_train_steps], outputs=huggingface_link)
41
-
42
- generate_button.click(generate_images, inputs=[prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale], outputs=output_images)
43
-
44
- push_button.click(push_to_huggingface, inputs=[HfFolder.path, repo_name], outputs=huggingface_link)
45
-
46
- demo.launch()
 
 
 
 
1
  import gradio as gr
2
+ import os
3
+ import shutil
4
+ from pathlib import Path
5
+ from main import fine_tune_model
6
+ from diffusers import StableDiffusionPipeline, DDIMScheduler
7
+ import torch
8
+
9
+ MODEL_NAME = "runwayml/stable-diffusion-v1-5"
10
+ OUTPUT_DIR = "/content/stable_diffusion_weights/custom_model"
11
+
12
+ def fine_tune(instance_prompt, images):
13
+ instance_data_dir = "/content/instance_images"
14
+ if os.path.exists(instance_data_dir):
15
+ shutil.rmtree(instance_data_dir)
16
+ os.makedirs(instance_data_dir, exist_ok=True)
17
+
18
+ for i, img in enumerate(images):
19
+ img.save(os.path.join(instance_data_dir, f"instance_{i}.png"))
20
+
21
+ fine_tune_model(instance_data_dir, instance_prompt, MODEL_NAME, OUTPUT_DIR)
22
+ return "Model fine-tuning complete."
23
+
24
+ def generate_images(prompt, num_samples, height, width, num_inference_steps, guidance_scale):
25
+ pipe = StableDiffusionPipeline.from_pretrained(OUTPUT_DIR, safety_checker=None, torch_dtype=torch.float16).to("cuda")
26
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
27
+ g_cuda = torch.Generator(device='cuda').manual_seed(1337)
28
+
29
+ with torch.autocast("cuda"), torch.inference_mode():
30
+ images = pipe(
31
+ prompt, height=height, width=width, num_images_per_prompt=num_samples,
32
+ num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=g_cuda
33
+ ).images
34
+
35
+ return images
36
 
37
+ def gradio_app():
38
  with gr.Blocks() as demo:
39
+ with gr.Tab("Fine-Tune Model"):
 
 
 
 
 
 
 
 
40
  with gr.Row():
41
+ with gr.Column():
42
+ instance_prompt = gr.Textbox(label="Instance Prompt")
43
+ image_input = gr.Image(label="Upload Images", source="upload", tool="editor", type="pil", multiple=True)
44
+ fine_tune_button = gr.Button("Fine-Tune Model")
45
+ output_text = gr.Textbox(label="Output")
46
+ fine_tune_button.click(fine_tune, inputs=[instance_prompt, image_input], outputs=output_text)
47
 
48
  with gr.Tab("Generate Images"):
49
  with gr.Row():
50
+ with gr.Column():
51
+ prompt = gr.Textbox(label="Prompt")
52
+ num_samples = gr.Number(label="Number of Samples", value=1)
53
+ guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
54
+ height = gr.Number(label="Height", value=512)
55
+ width = gr.Number(label="Width", value=512)
56
+ num_inference_steps = gr.Slider(label="Steps", value=50, minimum=1, maximum=100)
57
+ generate_button = gr.Button("Generate Images")
58
+ with gr.Column():
59
+ gallery = gr.Gallery(label="Generated Images")
60
+ generate_button.click(generate_images, inputs=[prompt, num_samples, height, width, num_inference_steps, guidance_scale], outputs=gallery)
61
 
62
+ demo.launch()
63
+
64
+ if __name__ == "__main__":
65
+ gradio_app()