visual-chatgpt-zh-vits / modules /controlnet_scibble.py
FrankZxShen's picture
Upload 55 files
aa69275
raw
history blame contribute delete
No virus
3.26 kB
from modules.utils import *
class Image2Scribble:
    """Tool that converts an image file into a scribble (sketch) image.

    The conversion is delegated to an ``HEDdetector`` loaded from the
    ``ControlNet`` folder inside the pretrained-model directory.
    """

    def __init__(self, device, pretrained_model_dir):
        """Load the detector used for scribble extraction.

        Args:
            device: Not used by this class (presumably kept so that all tools
                share the same constructor signature -- confirm with caller).
            pretrained_model_dir: Directory that contains the ``ControlNet``
                weights folder.
        """
        print("Initializing Image2Scribble")
        detector_dir = f'{pretrained_model_dir}/ControlNet'
        self.detector = HEDdetector.from_pretrained(detector_dir)

    @prompts(name="Sketch Detection On Image",
             description="useful when you want to generate a scribble of the image. "
                         "like: generate a scribble of this image, or generate a sketch from this image, "
                         "detect the sketch from this image. "
                         "The input to this tool should be a string, representing the image_path")
    def inference(self, inputs):
        """Run scribble detection on the image at ``inputs``.

        Args:
            inputs: Path of the image file to process.

        Returns:
            The path the scribble image was saved to, as produced by
            ``get_new_image_name(inputs, func_name="scribble")``.
        """
        source_image = Image.open(inputs)
        # scribble=True asks the detector for its scribble-style output.
        sketch = self.detector(source_image, scribble=True)

        updated_image_path = get_new_image_name(inputs, func_name="scribble")
        sketch.save(updated_image_path)

        print(f"\nProcessed Image2Scribble, Input Image: {inputs}, Output Scribble: {updated_image_path}")
        return updated_image_path
class ScribbleText2Image:
    """Tool that generates an image from a scribble image plus a text prompt.

    Wraps a ``StableDiffusionControlNetPipeline`` built from the
    ``stable-diffusion-v1-5`` weights and the ``sd-controlnet-scribble``
    ControlNet, both located under the pretrained-model directory.
    """

    def __init__(self, device, pretrained_model_dir):
        """Build the ControlNet pipeline and move it to ``device``.

        Args:
            device: Device string; float16 is selected when it contains
                ``'cuda'``, float32 otherwise.
            pretrained_model_dir: Directory holding the
                ``sd-controlnet-scribble`` and ``stable-diffusion-v1-5``
                model folders.
        """
        print("Initializing ScribbleText2Image to %s" % device)

        # Half precision on CUDA devices, full precision everywhere else.
        if 'cuda' in device:
            self.torch_dtype = torch.float16
        else:
            self.torch_dtype = torch.float32

        controlnet_dir = f"{pretrained_model_dir}/sd-controlnet-scribble"
        base_model_dir = f"{pretrained_model_dir}/stable-diffusion-v1-5"

        self.controlnet = ControlNetModel.from_pretrained(
            controlnet_dir,
            torch_dtype=self.torch_dtype,
        )
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            base_model_dir,
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=self.torch_dtype,
        )
        # Replace the pipeline's scheduler with UniPC, reusing its config.
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.to(device)

        # Placeholder; a fresh seed is drawn on every inference() call.
        self.seed = -1
        # Text appended to every user prompt.
        self.a_prompt = 'best quality, extremely detailed'
        # Negative prompt passed to the pipeline on every call.
        self.n_prompt = (
            'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, '
            'fewer digits, cropped, worst quality, low quality'
        )

    @prompts(name="Generate Image Condition On Sketch Image",
             description="useful when you want to generate a new real image from both the user desciption and "
                         "a scribble image or a sketch image. "
                         "The input to this tool should be a comma seperated string of two, "
                         "representing the image_path and the user description")
    def inference(self, inputs):
        """Generate an image conditioned on a scribble and a description.

        Args:
            inputs: ``"<image_path>,<description>"``. Everything before the
                first comma is the image path; everything after it (further
                commas included) is the description.

        Returns:
            The path the generated image was saved to, as produced by
            ``get_new_image_name(image_path, func_name="scribble2image")``.
        """
        # Split on the first comma only, so the description may contain commas.
        image_path, _, instruct_text = inputs.partition(",")
        scribble_image = Image.open(image_path)

        # Draw and apply a new random seed for this generation.
        self.seed = random.randint(0, 65535)
        seed_everything(self.seed)

        prompt = f"{instruct_text}, {self.a_prompt}"
        result = self.pipe(
            prompt,
            scribble_image,
            num_inference_steps=20,
            eta=0.0,
            negative_prompt=self.n_prompt,
            guidance_scale=9.0,
        )
        generated_image = result.images[0]

        updated_image_path = get_new_image_name(image_path, func_name="scribble2image")
        generated_image.save(updated_image_path)

        print(f"\nProcessed ScribbleText2Image, Input Scribble: {image_path}, Input Text: {instruct_text}, "
              f"Output Image: {updated_image_path}")
        return updated_image_path