tori29umai commited on
Commit
402fe71
·
1 Parent(s): 3757039
Files changed (1) hide show
  1. app.py +59 -55
app.py CHANGED
@@ -11,6 +11,20 @@ from utils.prompt_utils import remove_color
11
  from utils.tagger import modelLoad, analysis
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def load_model(lora_dir, cn_dir):
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  dtype = torch.float16
@@ -29,30 +43,49 @@ def load_model(lora_dir, cn_dir):
29
  return pipe
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  class Img2Img:
33
  def __init__(self):
34
  self.setup_paths()
35
  self.setup_models()
 
36
  self.post_filter = True
37
  self.tagger_model = None
38
  self.input_image_path = None
39
 
40
- def setup_paths(self):
41
- self.path = os.getcwd()
42
- self.cn_dir = f"{self.path}/controlnet"
43
- self.tagger_dir = f"{self.path}/tagger"
44
- self.lora_dir = f"{self.path}/lora"
45
- os.makedirs(self.cn_dir, exist_ok=True)
46
- os.makedirs(self.tagger_dir, exist_ok=True)
47
- os.makedirs(self.lora_dir, exist_ok=True)
48
-
49
- def setup_models(self):
50
- load_cn_model(self.cn_dir)
51
- load_cn_config(self.cn_dir)
52
- load_tagger_model(self.tagger_dir)
53
- load_lora_model(self.lora_dir)
54
-
55
-
56
  def process_prompt_analysis(self, input_image_path):
57
  if self.tagger_model is None:
58
  self.tagger_model = modelLoad(self.tagger_dir)
@@ -63,7 +96,7 @@ class Img2Img:
63
  return tags_list
64
 
65
 
66
- def launch(self):
67
  css = """
68
  #intro{
69
  max-width: 32rem;
@@ -77,8 +110,11 @@ class Img2Img:
77
  self.input_image_path = gr.Image(label="input_image", type='filepath')
78
  self.prompt = gr.Textbox(label="prompt", lines=3)
79
  self.negative_prompt = gr.Textbox(label="negative_prompt", lines=3, value="lowres, error, extra digit, fewer digits, cropped, worst quality,low quality, normal quality, jpeg artifacts, blurry")
 
80
  prompt_analysis_button = gr.Button("prompt解析")
 
81
  self.controlnet_scale = gr.Slider(minimum=0.5, maximum=1.25, value=1.0, step=0.01, label="線画忠実度")
 
82
  generate_button = gr.Button("生成")
83
  with gr.Column():
84
  self.output_image = gr.Image(type="pil", label="出力画像")
@@ -96,41 +132,9 @@ class Img2Img:
96
  inputs=[self.input_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
97
  outputs=self.output_image
98
  )
99
- self.demo.queue()
100
- self.demo.launch(share=True)
101
-
102
- @spaces.GPU
103
- def predict(self, input_image_path, prompt, negative_prompt, controlnet_scale):
104
- pipe = load_model(self.lora_dir, self.cn_dir)
105
- input_image_pil = Image.open(input_image_path)
106
- base_size = input_image_pil.size
107
- resize_image = resize_image_aspect_ratio(input_image_pil)
108
- resize_image_size = resize_image.size
109
- width, height = resize_image_size
110
- white_base_pil = base_generation(resize_image.size, (255, 255, 255, 255)).convert("RGB")
111
- generator = torch.manual_seed(0)
112
- last_time = time.time()
113
-
114
- output_image = pipe(
115
- image=white_base_pil,
116
- control_image=resize_image,
117
- strength=1.0,
118
- prompt=prompt,
119
- negative_prompt=negative_prompt,
120
- width=width,
121
- height=height,
122
- controlnet_conditioning_scale=float(controlnet_scale),
123
- controlnet_start=0.0,
124
- controlnet_end=1.0,
125
- generator=generator,
126
- num_inference_steps=30,
127
- guidance_scale=8.5,
128
- eta=1.0,
129
- ).images[0]
130
- print(f"Time taken: {time.time() - last_time}")
131
- output_image = output_image.resize(base_size, Image.LANCZOS)
132
- return output_image
133
-
134
- if __name__ == "__main__":
135
- ui = Img2Img()
136
- ui.launch()
 
11
  from utils.tagger import modelLoad, analysis
12
 
13
 
14
+
15
+ path = os.getcwd()
16
+ cn_dir = f"{path}/controlnet"
17
+ tagger_dir = f"{path}/tagger"
18
+ lora_dir = f"{path}/lora"
19
+ os.makedirs(cn_dir, exist_ok=True)
20
+ os.makedirs(tagger_dir, exist_ok=True)
21
+ os.makedirs(lora_dir, exist_ok=True)
22
+
23
+ load_cn_model(cn_dir)
24
+ load_cn_config(cn_dir)
25
+ load_tagger_model(tagger_dir)
26
+ load_lora_model(lora_dir)
27
+
28
  def load_model(lora_dir, cn_dir):
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
  dtype = torch.float16
 
43
  return pipe
44
 
45
 
46
+ @spaces.GPU
47
+ def predict(input_image_path, prompt, negative_prompt, controlnet_scale):
48
+ pipe = load_model(lora_dir, cn_dir)
49
+ input_image_pil = Image.open(input_image_path)
50
+ base_size = input_image_pil.size
51
+ resize_image = resize_image_aspect_ratio(input_image_pil)
52
+ resize_image_size = resize_image.size
53
+ width, height = resize_image_size
54
+ white_base_pil = base_generation(resize_image.size, (255, 255, 255, 255)).convert("RGB")
55
+ generator = torch.manual_seed(0)
56
+ last_time = time.time()
57
+
58
+ output_image = pipe(
59
+ image=white_base_pil,
60
+ control_image=resize_image,
61
+ strength=1.0,
62
+ prompt=prompt,
63
+ negative_prompt = negative_prompt,
64
+ width=width,
65
+ height=height,
66
+ controlnet_conditioning_scale=float(controlnet_scale),
67
+ controlnet_start=0.0,
68
+ controlnet_end=1.0,
69
+ generator=generator,
70
+ num_inference_steps=30,
71
+ guidance_scale=8.5,
72
+ eta=1.0,
73
+ ).images[0]
74
+ print(f"Time taken: {time.time() - last_time}")
75
+ output_image = output_image.resize(base_size, Image.LANCZOS)
76
+ return output_image
77
+
78
+
79
+
80
  class Img2Img:
81
  def __init__(self):
82
  self.setup_paths()
83
  self.setup_models()
84
+ self.demo = self.layout()
85
  self.post_filter = True
86
  self.tagger_model = None
87
  self.input_image_path = None
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def process_prompt_analysis(self, input_image_path):
90
  if self.tagger_model is None:
91
  self.tagger_model = modelLoad(self.tagger_dir)
 
96
  return tags_list
97
 
98
 
99
+ def layout(self):
100
  css = """
101
  #intro{
102
  max-width: 32rem;
 
110
  self.input_image_path = gr.Image(label="input_image", type='filepath')
111
  self.prompt = gr.Textbox(label="prompt", lines=3)
112
  self.negative_prompt = gr.Textbox(label="negative_prompt", lines=3, value="lowres, error, extra digit, fewer digits, cropped, worst quality,low quality, normal quality, jpeg artifacts, blurry")
113
+
114
  prompt_analysis_button = gr.Button("prompt解析")
115
+
116
  self.controlnet_scale = gr.Slider(minimum=0.5, maximum=1.25, value=1.0, step=0.01, label="線画忠実度")
117
+
118
  generate_button = gr.Button("生成")
119
  with gr.Column():
120
  self.output_image = gr.Image(type="pil", label="出力画像")
 
132
  inputs=[self.input_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
133
  outputs=self.output_image
134
  )
135
+ return demo
136
+
137
+
138
+
139
+ img2img = Img2Img()
140
+ img2img.demo.launch(share=True)