Akumetsu971 commited on
Commit
d0bd9ea
·
1 Parent(s): 3a509a0

Upload 3 files

Browse files
Files changed (3) hide show
  1. VCM07_style.pt +3 -0
  2. VCM07_style2.pt +3 -0
  3. prompt_blending.py +183 -0
VCM07_style.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8de463223c45d273ec77449808d47bce0b6987678ccc71cf3d413beba6ad3a17
3
+ size 25515
VCM07_style2.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f1aec4732e93aa0943a30f1d8c8ec666abc9e94b0d0136d43c58740bd3d510f
3
+ size 25515
prompt_blending.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import modules.scripts as scripts
2
+ import modules.prompt_parser as prompt_parser
3
+ import itertools
4
+ import torch
5
+
6
+
7
+ def hijacked_get_learned_conditioning(model, prompts, steps):
8
+ global real_get_learned_conditioning
9
+
10
+ if not hasattr(model, '__hacked'):
11
+ real_model_func = model.get_learned_conditioning
12
+
13
+ def hijacked_model_func(texts):
14
+ weighted_prompts = list(map(lambda t: get_weighted_prompt((t, 1)), texts))
15
+ all_texts = []
16
+ for weighted_prompt in weighted_prompts:
17
+ for (prompt, weight) in weighted_prompt:
18
+ all_texts.append(prompt)
19
+
20
+ if len(all_texts) > len(texts):
21
+ all_conds = real_model_func(all_texts)
22
+ offset = 0
23
+
24
+ conds = []
25
+
26
+ for weighted_prompt in weighted_prompts:
27
+ c = torch.zeros_like(all_conds[offset])
28
+ for (i, (prompt, weight)) in enumerate(weighted_prompt):
29
+ c = torch.add(c, all_conds[i+offset], alpha=weight)
30
+ conds.append(c)
31
+ offset += len(weighted_prompt)
32
+ return conds
33
+ else:
34
+ return real_model_func(texts)
35
+
36
+ model.get_learned_conditioning = hijacked_model_func
37
+ model.__hacked = True
38
+
39
+ switched_prompts = list(map(lambda p: switch_syntax(p), prompts))
40
+ return real_get_learned_conditioning(model, switched_prompts, steps)
41
+
42
+
43
+ real_get_learned_conditioning = hijacked_get_learned_conditioning # no really, overriden below
44
+
45
+
46
+ class Script(scripts.Script):
47
+ def title(self):
48
+ return "Prompt Blending"
49
+
50
+ def show(self, is_img2img):
51
+ global real_get_learned_conditioning
52
+ if real_get_learned_conditioning == hijacked_get_learned_conditioning:
53
+ real_get_learned_conditioning = prompt_parser.get_learned_conditioning
54
+ prompt_parser.get_learned_conditioning = hijacked_get_learned_conditioning
55
+ return False
56
+
57
+ def ui(self, is_img2img):
58
+ return []
59
+
60
+ def run(self, p, seeds):
61
+ return
62
+
63
+
64
+ OPEN = '{'
65
+ CLOSE = '}'
66
+ SEPARATE = '|'
67
+ MARK = '@'
68
+ REAL_MARK = ':'
69
+
70
+
71
+ def combine(left, right):
72
+ return map(lambda p: (p[0][0] + p[1][0], p[0][1] * p[1][1]), itertools.product(left, right))
73
+
74
+
75
+ def get_weighted_prompt(prompt_weight):
76
+ (prompt, full_weight) = prompt_weight
77
+ results = [('', full_weight)]
78
+ alts = []
79
+ start = 0
80
+ mark = -1
81
+ open_count = 0
82
+ first_open = 0
83
+ nested = False
84
+
85
+ for i, c in enumerate(prompt):
86
+ add_alt = False
87
+ do_combine = False
88
+ if c == OPEN:
89
+ open_count += 1
90
+ if open_count == 1:
91
+ first_open = i
92
+ results = list(combine(results, [(prompt[start:i], 1)]))
93
+ start = i + 1
94
+ else:
95
+ nested = True
96
+
97
+ if c == MARK and open_count == 1:
98
+ mark = i
99
+
100
+ if c == SEPARATE and open_count == 1:
101
+ add_alt = True
102
+
103
+ if c == CLOSE:
104
+ open_count -= 1
105
+ if open_count == 0:
106
+ add_alt = True
107
+ do_combine = True
108
+ if i == len(prompt) - 1 and open_count > 0:
109
+ add_alt = True
110
+ do_combine = True
111
+
112
+ if add_alt:
113
+ end = i
114
+ weight = 1
115
+ if mark != -1:
116
+ weight_str = prompt[mark + 1:i]
117
+ try:
118
+ weight = float(weight_str)
119
+ end = mark
120
+ except ValueError:
121
+ print("warning, not a number:", weight_str)
122
+
123
+
124
+
125
+ alt = (prompt[start:end], weight)
126
+ alts += get_weighted_prompt(alt) if nested else [alt]
127
+ nested = False
128
+ mark = -1
129
+ start = i + 1
130
+
131
+ if do_combine:
132
+ if len(alts) <= 1:
133
+ alts = [(prompt[first_open:i + 1], 1)]
134
+
135
+ results = list(combine(results, alts))
136
+ alts = []
137
+
138
+ # rest of the prompt
139
+ results = list(combine(results, [(prompt[start:], 1)]))
140
+ weight_sum = sum(map(lambda r: r[1], results))
141
+ results = list(map(lambda p: (p[0], p[1] / weight_sum * full_weight), results))
142
+
143
+ return results
144
+
145
+
146
+ def switch_syntax(prompt):
147
+ p = list(prompt)
148
+ stack = []
149
+ for i, c in enumerate(p):
150
+ if c == '{' or c == '[' or c == '(':
151
+ stack.append(c)
152
+
153
+ if len(stack) > 0:
154
+ if c == '}' or c == ']' or c == ')':
155
+ stack.pop()
156
+
157
+ if c == REAL_MARK and stack[-1] == '{':
158
+ p[i] = MARK
159
+
160
+ return "".join(p)
161
+
162
+ # def test(p, w=1):
163
+ # print('')
164
+ # print(p)
165
+ # result = get_weighted_prompt((p, w))
166
+ # print(result)
167
+ # print(sum(map(lambda x: x[1], result)))
168
+ #
169
+ #
170
+ # test("fantasy landscape")
171
+ # test("fantasy {landscape|city}, dark")
172
+ # test("fantasy {landscape|city}, {fire|ice} ")
173
+ # test("fantasy {landscape|city}, {fire|ice}, {dark|light} ")
174
+ # test("fantasy landscape, {{fire|lava}|ice}")
175
+ # test("fantasy landscape, {{fire@4|lava@1}|ice@2}")
176
+ # test("fantasy landscape, {{fire@error|lava@1}|ice@2}")
177
+ # test("fantasy landscape, {{fire|lava}|ice@2")
178
+ # test("fantasy landscape, {fire|lava} {cool} {ice,water}")
179
+ # test("fantasy landscape, {fire|lava} {cool} {ice,water")
180
+ # test("{lava|ice|water@5}")
181
+ # test("{fire@4|lava@1}", 5)
182
+ # test("{{fire@4|lava@1}|ice@2|water@5}")
183
+ # test("{fire|lava@3.5}")