fixing prompt code
- S2I/commons/controller.py +2 -2
- S2I/modules/sketch2image.py +2 -1
S2I/commons/controller.py
@@ -56,7 +56,7 @@ class Sketch2ImageController():
 
     def artwork(self, options, image, prompt, prompt_template, style_name, seed, val_r, faster, model_name, type_flag, prompt_quality):
         self.load_pipeline(zero_options=options)
-        prompt = prompt_template.replace("{prompt}", prompt)
+        # prompt = prompt_template.replace("{prompt}", prompt)
 
         # if type_flag == 'live-sketch':
         #     img = Image.fromarray(np.array(image["composite"])[:, :, -1])
@@ -79,7 +79,7 @@ class Sketch2ImageController():
         noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
 
         with torch.no_grad():
-            output_image = self.pipe.generate(c_t, prompt, prompt_quality, r=val_r, noise_map=noise, half_model=faster, model_name=model_name)
+            output_image = self.pipe.generate(c_t, prompt, prompt_quality, prompt_template, r=val_r, noise_map=noise, half_model=faster, model_name=model_name)
 
         output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
 
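Net effect of the controller change: artwork no longer applies the style template to the raw prompt itself; it forwards prompt_template to pipe.generate, which applies it after automatic enhancement. A minimal sketch of the ordering difference, where enhance is a hypothetical stand-in for automatic_enhance_prompt and the template string is invented:

    # Sketch only: `enhance` is a hypothetical stand-in for
    # automatic_enhance_prompt; the template string is invented.
    def enhance(prompt):
        return prompt + ", highly detailed"

    template = "cinematic photo of {prompt}, 35mm film"
    prompt = "a lighthouse at dusk"

    # Before this commit: template applied first, then enhancement.
    old_prompt = enhance(template.replace("{prompt}", prompt))
    # -> 'cinematic photo of a lighthouse at dusk, 35mm film, highly detailed'

    # After this commit: enhancement first, then template.
    new_prompt = template.replace("{prompt}", enhance(prompt))
    # -> 'cinematic photo of a lighthouse at dusk, highly detailed, 35mm film'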
S2I/modules/sketch2image.py
@@ -13,9 +13,10 @@ class Sketch2ImagePipeline(PrimaryModel):
         super().__init__()
         self.timestep = torch.tensor([999], device="cuda").long()
 
-    def generate(self, c_t, prompt=None, prompt_quality=None, prompt_tokens=None, r=1.0, noise_map=None, half_model=None, model_name=None):
+    def generate(self, c_t, prompt=None, prompt_quality=None, prompt_template=None, prompt_tokens=None, r=1.0, noise_map=None, half_model=None, model_name=None):
         self.from_pretrained(model_name=model_name, r=r)
         prompt_enhanced = self.automatic_enhance_prompt(prompt, prompt_quality)
+        prompt_enhanced = prompt_template.replace("{prompt}", prompt_enhanced)
         assert (prompt is None) != (prompt_tokens is None), "Either prompt or prompt_tokens should be provided"
 
         if half_model == 'float16':
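One caveat with the new signature: prompt_template defaults to None, so a caller that invokes generate without a template (the controller now always passes one) would hit None.replace(...) and raise AttributeError. A defensive variant of the added line, purely as a sketch under that assumption:

    # Sketch only: guard against callers that omit prompt_template.
    if prompt_template is not None:
        prompt_enhanced = prompt_template.replace("{prompt}", prompt_enhanced)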
|