Singularity666 committed on
Commit
b92dd65
·
verified ·
1 Parent(s): fb5828b

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +117 -0
main.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import shutil
4
+ from pathlib import Path
5
+ import torch
6
+ import gradio as gr
7
+ from diffusers import StableDiffusionPipeline, DDIMScheduler
8
+ from transformers import CLIPTextModel, CLIPTokenizer
9
+ from PIL import Image
10
+ from torch import autocast
11
+
12
# --- Model locations -------------------------------------------------------
# Base checkpoint that fine-tuning starts from, and where the result is saved.
MODEL_NAME = "runwayml/stable-diffusion-v1-5"
OUTPUT_DIR = "/output_model"

# --- Prompts ---------------------------------------------------------------
# INSTANCE_PROMPT is a template; "{identifier}" is filled in per subject.
INSTANCE_PROMPT = "photo of {identifier} person"
CLASS_PROMPT = "photo of a person"

# --- Training hyper-parameters (forwarded to the training script) ----------
SEED = 1337
RESOLUTION = 512
TRAIN_BATCH_SIZE = 1
LEARNING_RATE = 1e-6
MAX_TRAIN_STEPS = 800

# --- Inference defaults ----------------------------------------------------
GUIDANCE_SCALE = 8.0
NUM_INFERENCE_STEPS = 50
24
+
25
+ # Function to fine-tune the model
26
def fine_tune_model(instance_data_dir, identifier):
    """Fine-tune the base model on a subject's images via ``train_dreambooth.py``.

    Writes ``concepts_list.json`` into the current working directory and then
    runs the training script as a child process, waiting for it to finish.

    Args:
        instance_data_dir: Either a path to a directory of instance images, or
            an iterable of uploaded files (path strings, or objects exposing
            their path as ``.name``). Uploaded files are copied into a freshly
            created temporary directory, because the concepts list needs a
            directory path.
        identifier: Token substituted into ``INSTANCE_PROMPT`` to name the
            subject being learned.

    Returns:
        A human-readable status message describing the outcome.
    """
    # Imported locally so this function does not rely on edits to the
    # file-level import block.
    import subprocess
    import tempfile

    identifier = (identifier or "").strip()
    if not identifier:
        return "Fine-tuning not started: please enter an identifier."
    if not instance_data_dir:
        return "Fine-tuning not started: please provide instance images."

    if isinstance(instance_data_dir, (str, os.PathLike)):
        instance_dir = os.fspath(instance_data_dir)
    else:
        # A collection of uploads: gather them into one directory. The index
        # prefix keeps files with identical base names from overwriting
        # each other.
        instance_dir = tempfile.mkdtemp(prefix="instance_images_")
        for index, uploaded in enumerate(instance_data_dir):
            if isinstance(uploaded, (str, os.PathLike)):
                source = Path(uploaded)
            else:
                source = Path(uploaded.name)
            shutil.copy(source, Path(instance_dir) / f"{index:04d}_{source.name}")

    instance_prompt = INSTANCE_PROMPT.format(identifier=identifier)
    concepts_list = [
        {
            "instance_prompt": instance_prompt,
            "class_prompt": CLASS_PROMPT,
            "instance_data_dir": instance_dir,
            "class_data_dir": "/sample_data/person",  # Placeholder for regularization images
        }
    ]

    with open("concepts_list.json", "w", encoding="utf-8") as f:
        json.dump(concepts_list, f, indent=4)

    # Argument list with no shell: the identifier is user-supplied, so it must
    # never be interpreted by a shell.
    command = [
        "python3",
        "train_dreambooth.py",
        f"--pretrained_model_name_or_path={MODEL_NAME}",
        f"--output_dir={OUTPUT_DIR}",
        "--revision=fp16",
        "--with_prior_preservation",
        "--prior_loss_weight=1.0",
        f"--seed={SEED}",
        f"--resolution={RESOLUTION}",
        f"--train_batch_size={TRAIN_BATCH_SIZE}",
        "--train_text_encoder",
        "--mixed_precision=fp16",
        "--use_8bit_adam",
        "--gradient_accumulation_steps=1",
        f"--learning_rate={LEARNING_RATE}",
        f"--max_train_steps={MAX_TRAIN_STEPS}",
        f"--save_sample_prompt={instance_prompt}",
        "--concepts_list=concepts_list.json",
    ]

    try:
        completed = subprocess.run(command, check=False)
    except OSError as err:
        # E.g. the interpreter could not be found or executed.
        return f"Fine-tuning could not be started: {err}"

    if completed.returncode != 0:
        return (
            f"Fine-tuning failed (exit code {completed.returncode}); "
            "see the server log for details."
        )
    return f"Fine-tuning finished; model saved to {OUTPUT_DIR}."
61
+
62
+ # Function for inference
63
def generate_images(prompt, negative_prompt, num_samples, model_path, height=512, width=512, num_inference_steps=50, guidance_scale=7.5):
    """Generate images from a text prompt with a Stable Diffusion checkpoint.

    Loads the pipeline from ``model_path`` on every call, swaps in a DDIM
    scheduler, and samples on CUDA with a generator seeded from ``SEED``.

    Args:
        prompt: Text prompt to render.
        negative_prompt: Text describing what to steer away from.
        num_samples: Number of images to generate for the prompt.
        model_path: Path or hub id of the checkpoint to load.
        height: Output height in pixels.
        width: Output width in pixels.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance strength.

    Returns:
        The ``images`` attribute of the pipeline output.

    Raises:
        ValueError: If a count or dimension is smaller than 1.
    """
    # UI number widgets may deliver floats (e.g. 4.0); the pipeline needs real
    # integers for counts and pixel sizes, so coerce once at the boundary.
    num_samples = int(num_samples)
    height = int(height)
    width = int(width)
    num_inference_steps = int(num_inference_steps)
    guidance_scale = float(guidance_scale)

    if min(num_samples, height, width, num_inference_steps) < 1:
        raise ValueError(
            "num_samples, height, width and num_inference_steps must all be >= 1"
        )

    pipe = StableDiffusionPipeline.from_pretrained(
        model_path,
        safety_checker=None,
        torch_dtype=torch.float16,
    ).to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.enable_xformers_memory_efficient_attention()
    g_cuda = torch.Generator(device='cuda').manual_seed(SEED)

    with torch.autocast("cuda"), torch.inference_mode():
        images = pipe(
            prompt,
            height=height,
            width=width,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_samples,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=g_cuda,
        ).images

    return images
82
+
83
+ # Gradio UI function
84
def inference_ui(identifier, prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale):
    """Handle a "generate" request: build the final prompt and render images.

    The final prompt is the instance prompt for ``identifier`` followed by the
    user's prompt. Any literal ``{identifier}`` placeholder in the user's
    prompt is replaced with the identifier, so a templated prompt does not
    reach the model with the raw placeholder in it.

    Args:
        identifier: Subject token used during fine-tuning.
        prompt: User prompt; may contain the ``{identifier}`` placeholder.
        negative_prompt: Text describing what to steer away from.
        num_samples: Number of images to generate.
        height: Output height in pixels.
        width: Output width in pixels.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance strength.

    Returns:
        Whatever ``generate_images`` returns for the assembled prompt.
    """
    model_path = OUTPUT_DIR
    identifier = identifier or ""
    # str.replace rather than str.format: free-form user text may contain
    # other braces that str.format would choke on.
    user_prompt = (prompt or "").replace("{identifier}", identifier)
    full_prompt = INSTANCE_PROMPT.format(identifier=identifier) + ", " + user_prompt
    return generate_images(
        full_prompt,
        negative_prompt,
        num_samples,
        model_path,
        height,
        width,
        num_inference_steps,
        guidance_scale,
    )
89
+
90
+ # Define Gradio interface
91
def create_gradio_ui():
    """Build the two-column Gradio app and launch it.

    The left column collects an identifier and training images and triggers
    ``fine_tune_model``; the right column collects generation settings and
    triggers ``inference_ui``, showing the results in a gallery.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                identifier = gr.Textbox(label="Identifier", placeholder="Enter a unique identifier")
                image_upload = gr.File(label="Upload Images", file_count="multiple", type="file")
                finetune_button = gr.Button(value="Fine-Tune Model")
                finetune_output = gr.Textbox(label="Fine-Tuning Output")

            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="photo of {identifier} person in a marriage hall")
                negative_prompt = gr.Textbox(label="Negative Prompt", value="")
                # precision=0 makes these inputs yield integers, which the
                # image count and pixel dimensions need.
                num_samples = gr.Number(label="Number of Samples", value=4, precision=0)
                guidance_scale = gr.Number(label="Guidance Scale", value=8)
                height = gr.Number(label="Height", value=512, precision=0)
                width = gr.Number(label="Width", value=512, precision=0)
                # Bounded so the step count is a whole number of at least 1.
                num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=50)
                generate_button = gr.Button(value="Generate Images")
                gallery = gr.Gallery()

        # The handler is fine_tune_model; the previous name "finetune_model"
        # is not defined anywhere and raised NameError while building the UI.
        finetune_button.click(
            fine_tune_model,
            inputs=[image_upload, identifier],
            outputs=finetune_output,
        )
        generate_button.click(
            inference_ui,
            inputs=[
                identifier,
                prompt,
                negative_prompt,
                num_samples,
                height,
                width,
                num_inference_steps,
                guidance_scale,
            ],
            outputs=gallery,
        )

    demo.launch()
115
+
116
# Script entry point: build and launch the UI only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    create_gradio_ui()