Akumetsu971
commited on
Commit
·
d0bd9ea
1
Parent(s):
3a509a0
Upload 3 files
Browse files- VCM07_style.pt +3 -0
- VCM07_style2.pt +3 -0
- prompt_blending.py +183 -0
VCM07_style.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8de463223c45d273ec77449808d47bce0b6987678ccc71cf3d413beba6ad3a17
|
3 |
+
size 25515
|
VCM07_style2.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2f1aec4732e93aa0943a30f1d8c8ec666abc9e94b0d0136d43c58740bd3d510f
|
3 |
+
size 25515
|
prompt_blending.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import modules.scripts as scripts
|
2 |
+
import modules.prompt_parser as prompt_parser
|
3 |
+
import itertools
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
def hijacked_get_learned_conditioning(model, prompts, steps):
|
8 |
+
global real_get_learned_conditioning
|
9 |
+
|
10 |
+
if not hasattr(model, '__hacked'):
|
11 |
+
real_model_func = model.get_learned_conditioning
|
12 |
+
|
13 |
+
def hijacked_model_func(texts):
|
14 |
+
weighted_prompts = list(map(lambda t: get_weighted_prompt((t, 1)), texts))
|
15 |
+
all_texts = []
|
16 |
+
for weighted_prompt in weighted_prompts:
|
17 |
+
for (prompt, weight) in weighted_prompt:
|
18 |
+
all_texts.append(prompt)
|
19 |
+
|
20 |
+
if len(all_texts) > len(texts):
|
21 |
+
all_conds = real_model_func(all_texts)
|
22 |
+
offset = 0
|
23 |
+
|
24 |
+
conds = []
|
25 |
+
|
26 |
+
for weighted_prompt in weighted_prompts:
|
27 |
+
c = torch.zeros_like(all_conds[offset])
|
28 |
+
for (i, (prompt, weight)) in enumerate(weighted_prompt):
|
29 |
+
c = torch.add(c, all_conds[i+offset], alpha=weight)
|
30 |
+
conds.append(c)
|
31 |
+
offset += len(weighted_prompt)
|
32 |
+
return conds
|
33 |
+
else:
|
34 |
+
return real_model_func(texts)
|
35 |
+
|
36 |
+
model.get_learned_conditioning = hijacked_model_func
|
37 |
+
model.__hacked = True
|
38 |
+
|
39 |
+
switched_prompts = list(map(lambda p: switch_syntax(p), prompts))
|
40 |
+
return real_get_learned_conditioning(model, switched_prompts, steps)
|
41 |
+
|
42 |
+
|
43 |
+
real_get_learned_conditioning = hijacked_get_learned_conditioning # no really, overriden below
|
44 |
+
|
45 |
+
|
46 |
+
class Script(scripts.Script):
|
47 |
+
def title(self):
|
48 |
+
return "Prompt Blending"
|
49 |
+
|
50 |
+
def show(self, is_img2img):
|
51 |
+
global real_get_learned_conditioning
|
52 |
+
if real_get_learned_conditioning == hijacked_get_learned_conditioning:
|
53 |
+
real_get_learned_conditioning = prompt_parser.get_learned_conditioning
|
54 |
+
prompt_parser.get_learned_conditioning = hijacked_get_learned_conditioning
|
55 |
+
return False
|
56 |
+
|
57 |
+
def ui(self, is_img2img):
|
58 |
+
return []
|
59 |
+
|
60 |
+
def run(self, p, seeds):
|
61 |
+
return
|
62 |
+
|
63 |
+
|
64 |
+
OPEN = '{'
|
65 |
+
CLOSE = '}'
|
66 |
+
SEPARATE = '|'
|
67 |
+
MARK = '@'
|
68 |
+
REAL_MARK = ':'
|
69 |
+
|
70 |
+
|
71 |
+
def combine(left, right):
|
72 |
+
return map(lambda p: (p[0][0] + p[1][0], p[0][1] * p[1][1]), itertools.product(left, right))
|
73 |
+
|
74 |
+
|
75 |
+
def get_weighted_prompt(prompt_weight):
|
76 |
+
(prompt, full_weight) = prompt_weight
|
77 |
+
results = [('', full_weight)]
|
78 |
+
alts = []
|
79 |
+
start = 0
|
80 |
+
mark = -1
|
81 |
+
open_count = 0
|
82 |
+
first_open = 0
|
83 |
+
nested = False
|
84 |
+
|
85 |
+
for i, c in enumerate(prompt):
|
86 |
+
add_alt = False
|
87 |
+
do_combine = False
|
88 |
+
if c == OPEN:
|
89 |
+
open_count += 1
|
90 |
+
if open_count == 1:
|
91 |
+
first_open = i
|
92 |
+
results = list(combine(results, [(prompt[start:i], 1)]))
|
93 |
+
start = i + 1
|
94 |
+
else:
|
95 |
+
nested = True
|
96 |
+
|
97 |
+
if c == MARK and open_count == 1:
|
98 |
+
mark = i
|
99 |
+
|
100 |
+
if c == SEPARATE and open_count == 1:
|
101 |
+
add_alt = True
|
102 |
+
|
103 |
+
if c == CLOSE:
|
104 |
+
open_count -= 1
|
105 |
+
if open_count == 0:
|
106 |
+
add_alt = True
|
107 |
+
do_combine = True
|
108 |
+
if i == len(prompt) - 1 and open_count > 0:
|
109 |
+
add_alt = True
|
110 |
+
do_combine = True
|
111 |
+
|
112 |
+
if add_alt:
|
113 |
+
end = i
|
114 |
+
weight = 1
|
115 |
+
if mark != -1:
|
116 |
+
weight_str = prompt[mark + 1:i]
|
117 |
+
try:
|
118 |
+
weight = float(weight_str)
|
119 |
+
end = mark
|
120 |
+
except ValueError:
|
121 |
+
print("warning, not a number:", weight_str)
|
122 |
+
|
123 |
+
|
124 |
+
|
125 |
+
alt = (prompt[start:end], weight)
|
126 |
+
alts += get_weighted_prompt(alt) if nested else [alt]
|
127 |
+
nested = False
|
128 |
+
mark = -1
|
129 |
+
start = i + 1
|
130 |
+
|
131 |
+
if do_combine:
|
132 |
+
if len(alts) <= 1:
|
133 |
+
alts = [(prompt[first_open:i + 1], 1)]
|
134 |
+
|
135 |
+
results = list(combine(results, alts))
|
136 |
+
alts = []
|
137 |
+
|
138 |
+
# rest of the prompt
|
139 |
+
results = list(combine(results, [(prompt[start:], 1)]))
|
140 |
+
weight_sum = sum(map(lambda r: r[1], results))
|
141 |
+
results = list(map(lambda p: (p[0], p[1] / weight_sum * full_weight), results))
|
142 |
+
|
143 |
+
return results
|
144 |
+
|
145 |
+
|
146 |
+
def switch_syntax(prompt):
|
147 |
+
p = list(prompt)
|
148 |
+
stack = []
|
149 |
+
for i, c in enumerate(p):
|
150 |
+
if c == '{' or c == '[' or c == '(':
|
151 |
+
stack.append(c)
|
152 |
+
|
153 |
+
if len(stack) > 0:
|
154 |
+
if c == '}' or c == ']' or c == ')':
|
155 |
+
stack.pop()
|
156 |
+
|
157 |
+
if c == REAL_MARK and stack[-1] == '{':
|
158 |
+
p[i] = MARK
|
159 |
+
|
160 |
+
return "".join(p)
|
161 |
+
|
162 |
+
# def test(p, w=1):
|
163 |
+
# print('')
|
164 |
+
# print(p)
|
165 |
+
# result = get_weighted_prompt((p, w))
|
166 |
+
# print(result)
|
167 |
+
# print(sum(map(lambda x: x[1], result)))
|
168 |
+
#
|
169 |
+
#
|
170 |
+
# test("fantasy landscape")
|
171 |
+
# test("fantasy {landscape|city}, dark")
|
172 |
+
# test("fantasy {landscape|city}, {fire|ice} ")
|
173 |
+
# test("fantasy {landscape|city}, {fire|ice}, {dark|light} ")
|
174 |
+
# test("fantasy landscape, {{fire|lava}|ice}")
|
175 |
+
# test("fantasy landscape, {{fire@4|lava@1}|ice@2}")
|
176 |
+
# test("fantasy landscape, {{fire@error|lava@1}|ice@2}")
|
177 |
+
# test("fantasy landscape, {{fire|lava}|ice@2")
|
178 |
+
# test("fantasy landscape, {fire|lava} {cool} {ice,water}")
|
179 |
+
# test("fantasy landscape, {fire|lava} {cool} {ice,water")
|
180 |
+
# test("{lava|ice|water@5}")
|
181 |
+
# test("{fire@4|lava@1}", 5)
|
182 |
+
# test("{{fire@4|lava@1}|ice@2|water@5}")
|
183 |
+
# test("{fire|lava@3.5}")
|