hahafofo committed on
Commit 80fefdb
1 Parent(s): 890361d
Files changed (2):
  1. app.py +83 -124
  2. utils/generator.py +171 -0
app.py CHANGED
@@ -1,124 +1,76 @@
- import random
- import re
-
  import gradio as gr
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from transformers import pipeline, set_seed
 
+ from utils.exif import get_image_info
+ from utils.generator import generate_prompt
  from utils.image2text import git_image2text, w14_image2text, clip_image2text
- from utils.singleton import Singleton
  from utils.translate import en2zh as translate_en2zh
  from utils.translate import zh2en as translate_zh2en
- from utils.exif import get_image_info
 
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
- @Singleton
- class Models(object):
-
-     def __getattr__(self, item):
-         if item in self.__dict__:
-             return getattr(self, item)
-
-         if item in ('big_model', 'big_processor'):
-             self.big_model, self.big_processor = self.load_image2text_model()
-
-         if item in ('prompter_model', 'prompter_tokenizer'):
-             self.prompter_model, self.prompter_tokenizer = self.load_prompter_model()
-
-         if item in ('text_pipe',):
-             self.text_pipe = self.load_text_generation_pipeline()
-
-         return getattr(self, item)
-
-     @classmethod
-     def load_text_generation_pipeline(cls):
-         return pipeline('text-generation', model='succinctly/text2image-prompt-generator')
-
-     @classmethod
-     def load_prompter_model(cls):
-         prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
-         tokenizer = AutoTokenizer.from_pretrained("gpt2")
-         tokenizer.pad_token = tokenizer.eos_token
-         tokenizer.padding_side = "left"
-         return prompter_model, tokenizer
-
-
- models = Models.instance()
-
-
- def generate_prompter(plain_text, max_new_tokens=75, num_beams=8, num_return_sequences=8, length_penalty=-1.0):
-     input_ids = models.prompter_tokenizer(plain_text.strip() + " Rephrase:", return_tensors="pt").input_ids
-     eos_id = models.prompter_tokenizer.eos_token_id
-     outputs = models.prompter_model.generate(
-         input_ids,
-         do_sample=False,
-         max_new_tokens=max_new_tokens,
-         num_beams=num_beams,
-         num_return_sequences=num_return_sequences,
-         eos_token_id=eos_id,
-         pad_token_id=eos_id,
-         length_penalty=length_penalty
+ def text_generate_prompter(
+         plain_text,
+         model_name='microsoft',
+         prompt_min_length=60,
+         prompt_max_length=75,
+         prompt_num_return_sequences=8,
+ ):
+     result = generate_prompt(
+         plain_text=plain_text,
+         model_name=model_name,
+         min_length=prompt_min_length,
+         max_length=prompt_max_length,
+         num_return_sequences=prompt_num_return_sequences
      )
-     output_texts = models.prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)
-     result = []
-     for output_text in output_texts:
-         result.append(output_text.replace(plain_text + " Rephrase:", "").strip())
-
-     return "\n".join(result)
+     return result, "\n".join(translate_en2zh(line) for line in result.split("\n") if len(line) > 0)
 
 
  def image_generate_prompter(
          bclip_text,
          w14_text,
-         max_new_tokens=75,
-         num_beams=8,
-         num_return_sequences=8,
-         length_penalty=-1.0
+         model_name='microsoft',
+         prompt_min_length=60,
+         prompt_max_length=75,
+         prompt_num_return_sequences=8,
+
  ):
-     result = generate_prompter(
-         bclip_text,
-         max_new_tokens,
-         num_beams,
-         num_return_sequences,
-         length_penalty
+     result = generate_prompt(
+         plain_text=bclip_text,
+         model_name=model_name,
+         min_length=prompt_min_length,
+         max_length=prompt_max_length,
+         num_return_sequences=prompt_num_return_sequences
      )
-     return "\n".join(["{},{}".format(line.strip(), w14_text.strip()) for line in result.split("\n") if len(line) > 0])
-
-
- def text_generate(text_in_english):
-     seed = random.randint(100, 1000000)
-     set_seed(seed)
-
-     result = ""
-     for _ in range(6):
-         sequences = models.text_pipe(text_in_english, max_length=random.randint(60, 90), num_return_sequences=8)
-         list = []
-         for sequence in sequences:
-             line = sequence['generated_text'].strip()
-             if line != text_in_english and len(line) > (len(text_in_english) + 4) and line.endswith(
-                     (':', '-', '—')) is False:
-                 list.append(line)
-
-         result = "\n".join(list)
-         result = re.sub('[^ ]+\.[^ ]+', '', result)
-         result = result.replace('<', '').replace('>', '')
-         if result != '':
-             break
-     return result, "\n".join(translate_en2zh(line) for line in result.split("\n") if len(line) > 0)
+     prompter_list = ["{},{}".format(line.strip(), w14_text.strip()) for line in result.split("\n") if len(line) > 0]
+     prompter_zh_list = [
+         "{},{}".format(translate_en2zh(line.strip()), translate_en2zh(w14_text.strip())) for line in
+         result.split("\n") if len(line) > 0
+     ]
+     return "\n".join(prompter_list), "\n".join(prompter_zh_list)
 
 
  with gr.Blocks(title="Prompt生成器") as block:
      with gr.Column():
+         with gr.Tab('文本生成'):
+             with gr.Row():
+                 input_text = gr.Textbox(lines=6, label='你的想法', placeholder='在此输入内容...')
+                 translate_output = gr.Textbox(lines=6, label='翻译结果(Prompt输入)')
+
+             output = gr.Textbox(lines=6, label='优化的 Prompt')
+             output_zh = gr.Textbox(lines=6, label='优化的 Prompt(zh)')
+             with gr.Row():
+                 translate_btn = gr.Button('翻译')
+
+                 generate_prompter_btn = gr.Button('优化Prompt')
 
          with gr.Tab('从图片中生成'):
              with gr.Row():
                  input_image = gr.Image(type='pil')
                  exif_info = gr.HTML()
-             output_blip_or_clip = gr.Textbox(label='生成的 Prompt')
-             output_w14 = gr.Textbox(label='W14的 Prompt')
+             output_blip_or_clip = gr.Textbox(label='生成的 Prompt', lines=4)
+             output_w14 = gr.Textbox(label='W14的 Prompt', lines=4)
 
              with gr.Accordion('W14', open=False):
                  w14_raw_output = gr.Textbox(label="Output (raw string)")
@@ -126,38 +78,35 @@ with gr.Blocks(title="Prompt生成器") as block:
                  w14_rating_output = gr.Label(label="Rating")
                  w14_characters_output = gr.Label(label="Output (characters)")
                  w14_tags_output = gr.Label(label="Output (tags)")
-             images_generate_prompter_output = gr.Textbox(lines=6, label='SD优化的 Prompt')
+             output_img_prompter = gr.Textbox(lines=6, label='优化的 Prompt')
+             output_img_prompter_zh = gr.Textbox(lines=6, label='优化的 Prompt(zh)')
              with gr.Row():
                  img_exif_btn = gr.Button('EXIF')
                  img_blip_btn = gr.Button('BLIP图片转描述')
                  img_w14_btn = gr.Button('W14图片转描述')
                  img_clip_btn = gr.Button('CLIP图片转描述')
-                 img_prompter_btn = gr.Button('SD优化')
-
-         with gr.Tab('文本生成'):
-             with gr.Row():
-                 input_text = gr.Textbox(lines=6, label='你的想法', placeholder='在此输入内容...')
-                 translate_output = gr.Textbox(lines=6, label='翻译结果(Prompt输入)')
+                 img_prompter_btn = gr.Button('优化Prompt')
 
-             generate_prompter_output = gr.Textbox(lines=6, label='SD优化的 Prompt')
-
-             output = gr.Textbox(lines=6, label='瞎编的 Prompt')
-             output_zh = gr.Textbox(lines=6, label='瞎编的 Prompt(zh)')
-             with gr.Row():
-                 translate_btn = gr.Button('翻译')
-                 generate_prompter_btn = gr.Button('SD优化')
-                 gpt_btn = gr.Button('瞎编')
          with gr.Tab('参数设置'):
-             with gr.Accordion('SD优化参数', open=True):
-                 max_new_tokens = gr.Slider(1, 512, 100, label='max_new_tokens', step=1)
-                 nub_beams = gr.Slider(1, 30, 6, label='num_beams', step=1)
-                 num_return_sequences = gr.Slider(1, 30, 6, label='num_return_sequences', step=1)
-                 length_penalty = gr.Slider(-1.0, 1.0, -1.0, label='length_penalty')
+             with gr.Accordion('Prompt优化参数', open=True):
+                 prompt_mode_name = gr.Radio(
+                     [
+                         'microsoft',
+                         'mj',
+                         'gpt2_650k',
+                     ],
+                     value='gpt2_650k',
+                     label='model_name'
+                 )
+                 prompt_min_length = gr.Slider(1, 512, 100, label='min_length', step=1)
+                 prompt_max_length = gr.Slider(1, 512, 200, label='max_length', step=1)
+                 prompt_num_return_sequences = gr.Slider(1, 30, 6, label='num_return_sequences', step=1)
+
              with gr.Accordion('BLIP参数', open=True):
                  blip_max_length = gr.Slider(1, 512, 100, label='max_length', step=1)
              with gr.Accordion('CLIP参数', open=True):
                  clip_mode_type = gr.Radio(['best', 'classic', 'fast', 'negative'], value='best', label='mode_type')
-                 clip_model_name = gr.Radio(['vit_h_14', 'vit_l_14', ], value='vit_h_14', )
+                 clip_model_name = gr.Radio(['vit_h_14', 'vit_l_14', ], value='vit_h_14', label='model_name')
              with gr.Accordion('WD14参数', open=True):
                  image2text_model = gr.Radio(
                      [
@@ -185,22 +134,32 @@ with gr.Blocks(title="Prompt生成器") as block:
      )
      img_prompter_btn.click(
          fn=image_generate_prompter,
-         inputs=[output_blip_or_clip, output_w14, max_new_tokens, nub_beams, num_return_sequences, length_penalty],
-         outputs=images_generate_prompter_output,
+         inputs=[
+             output_blip_or_clip,
+             output_w14,
+             prompt_mode_name,
+             prompt_min_length,
+             prompt_max_length,
+             prompt_num_return_sequences,
+
+         ],
+         outputs=[output_img_prompter, output_img_prompter_zh]
      )
      translate_btn.click(
          fn=translate_zh2en,
          inputs=input_text,
          outputs=translate_output
      )
+
      generate_prompter_btn.click(
-         fn=generate_prompter,
-         inputs=[translate_output, max_new_tokens, nub_beams, num_return_sequences, length_penalty],
-         outputs=generate_prompter_output
-     )
-     gpt_btn.click(
-         fn=text_generate,
-         inputs=translate_output,
+         fn=text_generate_prompter,
+         inputs=[
+             translate_output,
+             prompt_mode_name,
+             prompt_min_length,
+             prompt_max_length,
+             prompt_num_return_sequences,
+         ],
          outputs=[output, output_zh]
      )
      img_w14_btn.click(
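
With the model plumbing moved into utils/generator.py, the Gradio callbacks above reduce to thin wrappers around a single entry point. A minimal sketch of exercising that entry point outside the UI (a hypothetical smoke test, assuming this repo's utils package is importable and the Hugging Face weights download on first use; the input prompt is illustrative):

from utils.generator import generate_prompt

# model_name selects the backend that Models loads lazily:
# 'microsoft' -> microsoft/Promptist (beam search),
# 'mj'        -> succinctly/text2image-prompt-generator,
# 'gpt2_650k' -> Ar4ikov/gpt2-650k-stable-diffusion-prompt-generator.
candidates = generate_prompt(
    plain_text="a cat sitting on a windowsill at sunset",  # illustrative input
    model_name="gpt2_650k",
    min_length=60,
    max_length=90,
    num_return_sequences=4,
)
print(candidates)  # one candidate prompt per line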
utils/generator.py ADDED
@@ -0,0 +1,171 @@
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers import pipeline, set_seed
+ import random
+ import re
+ from .singleton import Singleton
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ @Singleton
+ class Models(object):
+
+     def __getattr__(self, item):
+         if item in self.__dict__:
+             return getattr(self, item)
+
+         if item in ('microsoft_model', 'microsoft_tokenizer'):
+             self.microsoft_model, self.microsoft_tokenizer = self.load_microsoft_model()
+
+         if item in ('mj_pipe',):
+             self.mj_pipe = self.load_mj_pipe()
+
+         if item in ('gpt2_650k_pipe',):
+             self.gpt2_650k_pipe = self.load_gpt2_650k_pipe()
+
+         return getattr(self, item)
+
+     @classmethod
+     def load_gpt2_650k_pipe(cls):
+
+         return pipeline('text-generation', model='Ar4ikov/gpt2-650k-stable-diffusion-prompt-generator')
+
+     @classmethod
+     def load_mj_pipe(cls):
+         return pipeline('text-generation', model='succinctly/text2image-prompt-generator')
+
+     @classmethod
+     def load_microsoft_model(cls):
+         prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
+         tokenizer = AutoTokenizer.from_pretrained("gpt2")
+         tokenizer.pad_token = tokenizer.eos_token
+         tokenizer.padding_side = "left"
+         return prompter_model, tokenizer
+
+
+ models = Models.instance()
+
+
+ def rand_length(min_length: int = 60, max_length: int = 90) -> int:
+     if min_length > max_length:
+         return max_length
+
+     return random.randint(min_length, max_length)
+
+
+ def generate_prompt(
+         plain_text,
+         min_length=60,
+         max_length=90,
+         num_return_sequences=8,
+         model_name='microsoft',
+ ):
+     if model_name == 'gpt2_650k':
+         return generate_prompt_gpt2_650k(
+             prompt=plain_text,
+             min_length=min_length,
+             max_length=max_length,
+             num_return_sequences=num_return_sequences,
+         )
+     elif model_name == 'mj':
+         return generate_prompt_mj(
+             text_in_english=plain_text,
+             num_return_sequences=num_return_sequences,
+             min_length=min_length,
+             max_length=max_length,
+         )
+     else:
+         return generate_prompt_microsoft(
+             plain_text=plain_text,
+             min_length=min_length,
+             max_length=max_length,
+             num_return_sequences=num_return_sequences,
+             num_beams=num_return_sequences,
+         )
+
+
+ def generate_prompt_microsoft(
+         plain_text,
+         min_length=60,
+         max_length=90,
+         num_beams=8,
+         num_return_sequences=8,
+         length_penalty=-1.0
+ ) -> str:
+     input_ids = models.microsoft_tokenizer(plain_text.strip() + " Rephrase:", return_tensors="pt").input_ids
+     eos_id = models.microsoft_tokenizer.eos_token_id
+
+     outputs = models.microsoft_model.generate(
+         input_ids,
+         do_sample=False,
+         max_new_tokens=rand_length(min_length, max_length),
+         num_beams=num_beams,
+         num_return_sequences=num_return_sequences,
+         eos_token_id=eos_id,
+         pad_token_id=eos_id,
+         length_penalty=length_penalty
+     )
+     output_texts = models.microsoft_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+     result = []
+     for output_text in output_texts:
+         result.append(output_text.replace(plain_text + " Rephrase:", "").strip())
+
+     return "\n".join(result)
+
+
+ def generate_prompt_gpt2_650k(prompt: str, min_length=60, max_length: int = 255, num_return_sequences: int = 8) -> str:
+     def get_valid_prompt(text: str) -> str:
+         dot_split = text.split('.')[0]
+         n_split = text.split('\n')[0]
+
+         return {
+             len(dot_split) < len(n_split): dot_split,
+             len(n_split) > len(dot_split): n_split,
+             len(n_split) == len(dot_split): dot_split
+         }[True]
+
+     output = []
+     for _ in range(6):
+
+         output += [
+             get_valid_prompt(result['generated_text']) for result in
+             models.gpt2_650k_pipe(
+                 prompt,
+                 max_new_tokens=rand_length(min_length, max_length),
+                 num_return_sequences=num_return_sequences
+             )
+         ]
+         output = list(set(output))
+         if len(output) >= num_return_sequences:
+             break
+
+     # valid_prompt = get_valid_prompt(models.gpt2_650k_pipe(prompt, max_length=max_length)[0]['generated_text'])
+     return "\n".join([o.strip() for o in output])
+
+
+ def generate_prompt_mj(text_in_english: str, num_return_sequences: int = 8, min_length=60, max_length=90) -> str:
+     seed = random.randint(100, 1000000)
+     set_seed(seed)
+
+     result = ""
+     for _ in range(6):
+         sequences = models.mj_pipe(
+             text_in_english,
+             max_new_tokens=rand_length(min_length, max_length),
+             num_return_sequences=num_return_sequences
+         )
+         list = []
+         for sequence in sequences:
+             line = sequence['generated_text'].strip()
+             if line != text_in_english and len(line) > (len(text_in_english) + 4) and line.endswith(
+                     (':', '-', '—')) is False:
+                 list.append(line)
+
+         result = "\n".join(list)
+         result = re.sub('[^ ]+\.[^ ]+', '', result)
+         result = result.replace('<', '').replace('>', '')
+         if result != '':
+             break
+     return result
+     # return result, "\n".join(translate_en2zh(line) for line in result.split("\n") if len(line) > 0)
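
One edge case worth flagging in get_valid_prompt: the bool-keyed dict's first two keys are the same comparison, and when dot_split is strictly longer than n_split (a newline occurs before the first period) every key evaluates to False, so the [True] lookup raises KeyError. A hedged rewrite, not part of this commit, that keeps the defined behavior (longer split wins, dot_split on ties) while removing the failure mode:

def get_valid_prompt(text: str) -> str:
    # Prefix up to the first '.' and up to the first newline.
    dot_split = text.split('.')[0]
    n_split = text.split('\n')[0]
    # max() with a len key reproduces the dict version's defined cases
    # (returns the longer prefix, dot_split on ties) and simply returns
    # dot_split instead of raising when dot_split is the longer piece.
    return max(dot_split, n_split, key=len)

Two smaller nits carried over verbatim from the old app.py: generate_prompt_mj still shadows the built-in list, and re.sub('[^ ]+\.[^ ]+', ...) uses an unescaped '\.' in a plain string, which newer Python versions flag with a SyntaxWarning; a raw string r'[^ ]+\.[^ ]+' avoids it.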