# CM2000112 / inference.py
# Last commit: 35575bb ("update : inference", verified) by jayparmr
import os
import traceback
from typing import List, Optional
import pydash as _
import torch
from botocore.vendored.six import BytesIO
from numpy import who
import internals.util.prompt as prompt_util
from internals.data.dataAccessor import update_db, update_db_source_failed
from internals.data.task import ModelType, Task, TaskType
from internals.pipelines.commons import Img2Img, Text2Img
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.high_res import HighRes
from internals.pipelines.img_classifier import ImageClassifier
from internals.pipelines.img_to_text import Image2Text
from internals.pipelines.inpainter import InPainter
from internals.pipelines.object_remove import ObjectRemoval
from internals.pipelines.prompt_modifier import PromptModifier
from internals.pipelines.realtime_draw import RealtimeDraw
from internals.pipelines.remove_background import RemoveBackgroundV3
from internals.pipelines.replace_background import ReplaceBackground
from internals.pipelines.safety_checker import SafetyChecker
from internals.pipelines.sdxl_tile_upscale import SDXLTileUpscaler
from internals.pipelines.upscaler import Upscaler
from internals.util.args import apply_style_args
from internals.util.avatar import Avatar
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda, clear_cuda_and_gc
from internals.util.commons import (
base64_to_image,
construct_default_s3_url,
download_image,
image_to_base64,
upload_image,
upload_images,
)
from internals.util.config import (
get_is_sdxl,
get_low_gpu_mem,
get_model_dir,
get_num_return_sequences,
set_configs_from_task,
set_model_config,
set_root_dir,
)
from internals.util.lora_style import LoraStyle
from internals.util.model_loader import load_model_from_config
from internals.util.slack import Slack
# Global torch switches: let cuDNN benchmark and pick the fastest convolution
# algorithms, and allow TensorFloat-32 for CUDA matmuls.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

auto_mode = False

# Module-level singletons shared by every task handler in this file. Each is
# constructed exactly once at import time.
prompt_modifier = PromptModifier(num_of_sequences=get_num_return_sequences())
upscaler = Upscaler()
inpainter = InPainter()
high_res = HighRes()
img2text = Image2Text()
img_classifier = ImageClassifier()
object_removal = ObjectRemoval()
replace_background = ReplaceBackground()
remove_background_v3 = RemoveBackgroundV3()
controlnet = ControlNet()
lora_style = LoraStyle()
text2img_pipe = Text2Img()
img2img_pipe = Img2Img()
safety_checker = SafetyChecker()
slack = Slack()
avatar = Avatar()
realtime_draw = RealtimeDraw()
sdxl_tileupscaler = SDXLTileUpscaler()

# Cache of instantiated custom scripts, filled lazily by custom_action().
custom_scripts: List = []
def get_patched_prompt(task: Task):
    """Delegate to ``prompt_util.get_patched_prompt`` for ``task``.

    Supplies this module's avatar, LoRA-style and prompt-modifier singletons.
    """
    shared_helpers = (avatar, lora_style, prompt_modifier)
    return prompt_util.get_patched_prompt(task, *shared_helpers)
def get_patched_prompt_text2img(task: Task):
    """Delegate to ``prompt_util.get_patched_prompt_text2img`` for ``task``.

    Supplies this module's avatar, LoRA-style and prompt-modifier singletons.
    """
    shared_helpers = (avatar, lora_style, prompt_modifier)
    return prompt_util.get_patched_prompt_text2img(task, *shared_helpers)
def get_patched_prompt_tile_upscale(task: Task):
    """Delegate to ``prompt_util.get_patched_prompt_tile_upscale`` for ``task``.

    Supplies the avatar, LoRA style, image classifier and image-to-text
    singletons, plus the current SDXL flag from config.
    """
    helpers = (avatar, lora_style, img_classifier, img2text)
    return prompt_util.get_patched_prompt_tile_upscale(
        task, *helpers, is_sdxl=get_is_sdxl()
    )
def get_intermediate_dimension(task: Task):
    """Return the (width, height) to generate at before any high-res pass.

    When the task requests a high-res fix this is
    ``HighRes.get_intermediate_dimension`` of the requested size; otherwise it
    is simply the requested size.
    """
    if not task.get_high_res_fix():
        return task.get_width(), task.get_height()
    return HighRes.get_intermediate_dimension(task.get_width(), task.get_height())
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def canny(task: Task):
    """Generate images with the ``canny`` controlnet from the task's input image.

    Runs on ``controlnet.pipe2`` with the task's LoRA style patched in. An
    optional high-res pass follows. The condition image is uploaded next to
    the results.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    controlnet.load_model("canny")

    # pipe2 is used for canny and pose
    lora_patcher = lora_style.get_patcher(
        [controlnet.pipe2, high_res.pipe], task.get_style()
    )
    lora_patcher.patch()
    # Unpatch the LoRA even if generation/upload fails. The pipes are shared
    # module-level singletons, and predict_fn's error handler only cleans up
    # the controlnet, so a leaked patch would bleed into the next task.
    try:
        kwargs = {
            "prompt": prompt,
            "imageUrl": task.get_imageUrl(),
            "seed": task.get_seed(),
            "num_inference_steps": task.get_steps(),
            "width": width,
            "height": height,
            "negative_prompt": [task.get_negative_prompt()]
            * get_num_return_sequences(),
            "apply_preprocess": task.get_apply_preprocess(),
            **task.cnc_kwargs(),
            **lora_patcher.kwargs(),
        }
        (images, has_nsfw), control_image = controlnet.process(**kwargs)

        if task.get_high_res_fix():
            kwargs = {
                "prompt": prompt,
                "negative_prompt": [task.get_negative_prompt()]
                * get_num_return_sequences(),
                "images": images,
                "seed": task.get_seed(),
                "width": task.get_width(),
                "height": task.get_height(),
                "num_inference_steps": task.get_steps(),
                **task.high_res_kwargs(),
            }
            images, _ = high_res.apply(**kwargs)

        upload_image(
            control_image, f"crecoAI/{task.get_taskId()}_condition.png"  # pyright: ignore
        )
        generated_image_urls = upload_images(images, "_canny", task.get_taskId())
    finally:
        lora_patcher.cleanup()
    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def canny_img2img(task: Task):
    """Generate images with the ``canny_2x`` controlnet (``controlnet.pipe``).

    If high-res fix is enabled, the high-res output is also run through the
    regular upscaler and resized back to the requested width/height.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    controlnet.load_model("canny_2x")

    lora_patcher = lora_style.get_patcher(
        [controlnet.pipe, high_res.pipe], task.get_style()
    )
    lora_patcher.patch()
    # Always unpatch the LoRA, even on failure: the pipes are shared singletons
    # and predict_fn's error handler only cleans up the controlnet.
    try:
        kwargs = {
            "prompt": prompt,
            "imageUrl": task.get_imageUrl(),
            "seed": task.get_seed(),
            "num_inference_steps": task.get_steps(),
            "width": width,
            "height": height,
            "negative_prompt": [task.get_negative_prompt()]
            * get_num_return_sequences(),
            **task.cnci2i_kwargs(),
            **lora_patcher.kwargs(),
        }
        (images, has_nsfw), control_image = controlnet.process(**kwargs)

        if task.get_high_res_fix():
            # we run both here normal upscaler and highres
            # and show normal upscaler image as output
            # but use highres image for tile upscale
            # NOTE(review): the upscaler loop below is placed inside this branch
            # per the comment above (original indentation was lost) — confirm.
            kwargs = {
                "prompt": prompt,
                "negative_prompt": [task.get_negative_prompt()]
                * get_num_return_sequences(),
                "images": images,
                "seed": task.get_seed(),
                "width": task.get_width(),
                "height": task.get_height(),
                "num_inference_steps": task.get_steps(),
                **task.high_res_kwargs(),
            }
            images, _ = high_res.apply(**kwargs)

            for i, image in enumerate(images):
                img = upscaler.upscale(
                    image=image,
                    width=task.get_width(),
                    height=task.get_height(),
                    face_enhance=task.get_face_enhance(),
                    resize_dimension=None,
                )
                img = Upscaler.to_pil(img)
                images[i] = img.resize((task.get_width(), task.get_height()))

        upload_image(
            control_image, f"crecoAI/{task.get_taskId()}_condition.png"  # pyright: ignore
        )
        generated_image_urls = upload_images(images, "_canny_2x", task.get_taskId())
    finally:
        lora_patcher.cleanup()
    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def tile_upscale(task: Task):
    """Upscale the task image with the ``tile_upscaler`` controlnet.

    The first result is uploaded to ``crecoAI/<taskId>_tile_upscaler.png``.

    Returns:
        dict with ``modified_prompts``, ``generated_image_url`` (single URL)
        and ``has_nsfw``.
    """
    output_key = "crecoAI/{}_tile_upscaler.png".format(task.get_taskId())

    prompt = get_patched_prompt_tile_upscale(task)

    controlnet.load_model("tile_upscaler")

    lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
    lora_patcher.patch()
    # Unpatch even if processing fails so the shared pipe is left clean.
    try:
        kwargs = {
            "imageUrl": task.get_imageUrl(),
            "seed": task.get_seed(),
            "num_inference_steps": task.get_steps(),
            "negative_prompt": task.get_negative_prompt(),
            "width": task.get_width(),
            "height": task.get_height(),
            "prompt": prompt,
            "resize_dimension": task.get_resize_dimension(),
            **task.cnt_kwargs(),
        }
        (images, has_nsfw), _ = controlnet.process(**kwargs)
    finally:
        lora_patcher.cleanup()
    controlnet.cleanup()

    generated_image_url = upload_image(images[0], output_key)

    return {
        "modified_prompts": prompt,
        "generated_image_url": generated_image_url,
        "has_nsfw": has_nsfw,
    }
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def scribble(task: Task):
    """Generate images with the ``scribble`` controlnet (``controlnet.pipe2``).

    The input image is preprocessed once by ``controlnet.preprocess_image`` and
    repeated for every requested sequence. An optional high-res pass follows.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    controlnet.load_model("scribble")

    lora_patcher = lora_style.get_patcher(
        [controlnet.pipe2, high_res.pipe], task.get_style()
    )
    lora_patcher.patch()
    # Unpatch even on failure: shared pipes, and predict_fn only cleans the controlnet.
    try:
        image = controlnet.preprocess_image(task.get_imageUrl(), width, height)

        kwargs = {
            "image": [image] * get_num_return_sequences(),
            "seed": task.get_seed(),
            "num_inference_steps": task.get_steps(),
            "width": width,
            "height": height,
            "prompt": prompt,
            "negative_prompt": [task.get_negative_prompt()]
            * get_num_return_sequences(),
            "apply_preprocess": task.get_apply_preprocess(),
            **task.cns_kwargs(),
        }
        (images, has_nsfw), condition_image = controlnet.process(**kwargs)

        if task.get_high_res_fix():
            kwargs = {
                "prompt": prompt,
                "negative_prompt": [task.get_negative_prompt()]
                * get_num_return_sequences(),
                "images": images,
                "width": task.get_width(),
                "height": task.get_height(),
                "seed": task.get_seed(),
                "num_inference_steps": task.get_steps(),
                **task.high_res_kwargs(),
            }
            images, _ = high_res.apply(**kwargs)

        upload_image(
            condition_image, f"crecoAI/{task.get_taskId()}_condition.png"  # pyright: ignore
        )
        generated_image_urls = upload_images(images, "_scribble", task.get_taskId())
    finally:
        lora_patcher.cleanup()
    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def linearart(task: Task):
    """Generate images with the ``linearart`` controlnet (``controlnet.pipe2``).

    An optional high-res pass follows. The condition image is uploaded next to
    the results.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    controlnet.load_model("linearart")

    lora_patcher = lora_style.get_patcher(
        [controlnet.pipe2, high_res.pipe], task.get_style()
    )
    lora_patcher.patch()
    # Unpatch even on failure: shared pipes, and predict_fn only cleans the controlnet.
    try:
        kwargs = {
            "imageUrl": task.get_imageUrl(),
            "seed": task.get_seed(),
            "num_inference_steps": task.get_steps(),
            "width": width,
            "height": height,
            "prompt": prompt,
            "negative_prompt": [task.get_negative_prompt()]
            * get_num_return_sequences(),
            "apply_preprocess": task.get_apply_preprocess(),
            **task.cnl_kwargs(),
        }
        (images, has_nsfw), condition_image = controlnet.process(**kwargs)

        if task.get_high_res_fix():
            kwargs = {
                "prompt": prompt,
                "negative_prompt": [task.get_negative_prompt()]
                * get_num_return_sequences(),
                "images": images,
                "seed": task.get_seed(),
                "width": task.get_width(),
                "height": task.get_height(),
                "num_inference_steps": task.get_steps(),
                **task.high_res_kwargs(),
            }
            images, _ = high_res.apply(**kwargs)

        upload_image(
            condition_image, f"crecoAI/{task.get_taskId()}_condition.png"  # pyright: ignore
        )
        generated_image_urls = upload_images(images, "_linearart", task.get_taskId())
    finally:
        lora_patcher.cleanup()
    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
    """Generate images conditioned on a pose image.

    On non-SDXL configs, an auxiliary depth image is also used (depth + pose
    multi-controlnet).

    Args:
        task: the request.
        s3_outkey: suffix passed to ``upload_images`` for the generated images.
        poses: currently ignored — every branch below recomputes the pose
            images, overwriting any value passed in.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    controlnet.load_model("pose")

    # pipe2 is used for canny and pose
    lora_patcher = lora_style.get_patcher(
        [controlnet.pipe2, high_res.pipe], task.get_style()
    )
    lora_patcher.patch()
    # Unpatch even on failure: shared pipes, and predict_fn only cleans the controlnet.
    try:
        if not task.get_apply_preprocess():
            poses = [download_image(task.get_imageUrl()).resize((width, height))]
        elif not task.get_pose_estimation():
            print("Not detecting pose")
            pose_image = download_image(task.get_imageUrl()).resize(
                (task.get_width(), task.get_height())
            )
            poses = [pose_image] * get_num_return_sequences()
        else:
            poses = [
                controlnet.detect_pose(task.get_imageUrl())
            ] * get_num_return_sequences()

        # Work on ONE private copy of the task's pose kwargs. Previously the
        # scalars were popped from one cnp_kwargs() call and a fresh call was
        # spread later, which could re-introduce the scalar values and
        # override the per-controlnet lists below (or mutate task state).
        pose_kwargs = dict(task.cnp_kwargs())

        if not get_is_sdxl():
            # in normal pipeline we use depth + pose controlnet
            depth = download_image(task.get_auxilary_imageUrl()).resize(
                (task.get_width(), task.get_height())
            )
            depth = ControlNet.depth_image(depth)
            control_images = [depth, poses[0]]
            upload_image(depth, "crecoAI/{}_depth.png".format(task.get_taskId()))

            # Per-controlnet values, ordered [depth, pose]; the task's scalar
            # values (if any) apply to the pose net only.
            scale = pose_kwargs.pop("controlnet_conditioning_scale", None)
            factor = pose_kwargs.pop("control_guidance_end", None)
            extra_kwargs = {
                "controlnet_conditioning_scale": [1.0, scale or 1.0],
                "control_guidance_end": [0.5, factor or 1.0],
            }
        else:
            control_images = poses[0]
            extra_kwargs = {}

        kwargs = {
            "prompt": prompt,
            "image": control_images,
            "seed": task.get_seed(),
            "num_inference_steps": task.get_steps(),
            "negative_prompt": [task.get_negative_prompt()]
            * get_num_return_sequences(),
            "width": width,
            "height": height,
            **extra_kwargs,
            **pose_kwargs,
            **lora_patcher.kwargs(),
        }
        (images, has_nsfw), _ = controlnet.process(**kwargs)

        if task.get_high_res_fix():
            kwargs = {
                "prompt": prompt,
                "negative_prompt": [task.get_negative_prompt()]
                * get_num_return_sequences(),
                "images": images,
                "width": task.get_width(),
                "height": task.get_height(),
                "num_inference_steps": task.get_steps(),
                "seed": task.get_seed(),
                **task.high_res_kwargs(),
            }
            images, _ = high_res.apply(**kwargs)

        upload_image(poses[0], "crecoAI/{}_condition.png".format(task.get_taskId()))
        generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())
    finally:
        lora_patcher.cleanup()
    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def text2img(task: Task):
    """Plain text-to-image generation with the task's LoRA style.

    An optional high-res pass follows.

    Returns:
        dict with the patched prompt params (``params.__dict__``) plus
        ``generated_image_urls`` and ``has_nsfw``.
    """
    params = get_patched_prompt_text2img(task)
    width, height = get_intermediate_dimension(task)

    lora_patcher = lora_style.get_patcher(
        [text2img_pipe.pipe, high_res.pipe], task.get_style()
    )
    lora_patcher.patch()
    # Unpatch even on failure so the shared text2img/high-res pipes stay clean.
    try:
        kwargs = {
            "params": params,
            "num_inference_steps": task.get_steps(),
            "height": height,
            "seed": task.get_seed(),
            "width": width,
            "negative_prompt": task.get_negative_prompt(),
            **task.t2i_kwargs(),
            **lora_patcher.kwargs(),
        }
        images, has_nsfw = text2img_pipe.process(**kwargs)

        if task.get_high_res_fix():
            kwargs = {
                # fall back to empty prompts when the patched params carry none
                "prompt": params.prompt
                if params.prompt
                else [""] * get_num_return_sequences(),
                "negative_prompt": [task.get_negative_prompt()]
                * get_num_return_sequences(),
                "images": images,
                "width": task.get_width(),
                "height": task.get_height(),
                "num_inference_steps": task.get_steps(),
                "seed": task.get_seed(),
                **task.high_res_kwargs(),
            }
            images, _ = high_res.apply(**kwargs)

        generated_image_urls = upload_images(images, "", task.get_taskId())
    finally:
        lora_patcher.cleanup()

    return {
        **params.__dict__,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def img2img(task: Task):
    """Image-to-image generation, with an optional high-res pass.

    On SDXL configs this runs through the ``canny`` controlnet
    (``controlnet.pipe2``). Otherwise it uses the dedicated img2img pipeline.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    is_sdxl = get_is_sdxl()
    if is_sdxl:
        # SDXL img2img is served by the canny controlnet
        controlnet.load_model("canny")
        lora_patcher = lora_style.get_patcher(
            [controlnet.pipe2, high_res.pipe], task.get_style()
        )
    else:
        lora_patcher = lora_style.get_patcher(
            [img2img_pipe.pipe, high_res.pipe], task.get_style()
        )
    lora_patcher.patch()
    # Unpatch even on failure: shared pipes, and predict_fn only cleans the controlnet.
    try:
        if is_sdxl:
            kwargs = {
                "imageUrl": task.get_imageUrl(),
                "seed": task.get_seed(),
                "num_inference_steps": task.get_steps(),
                "width": width,
                "height": height,
                "prompt": prompt,
                "negative_prompt": [task.get_negative_prompt()]
                * get_num_return_sequences(),
                # default strength; task i2i kwargs below may override it
                "controlnet_conditioning_scale": 0.5,
                **task.i2i_kwargs(),
            }
            (images, has_nsfw), _ = controlnet.process(**kwargs)
        else:
            kwargs = {
                "prompt": prompt,
                "imageUrl": task.get_imageUrl(),
                "negative_prompt": [task.get_negative_prompt()]
                * get_num_return_sequences(),
                "num_inference_steps": task.get_steps(),
                "width": width,
                "height": height,
                "seed": task.get_seed(),
                **task.i2i_kwargs(),
                **lora_patcher.kwargs(),
            }
            images, has_nsfw = img2img_pipe.process(**kwargs)

        # NOTE(review): placed at function level (applies to both branches);
        # original indentation was lost — confirm.
        if task.get_high_res_fix():
            kwargs = {
                "prompt": prompt,
                "negative_prompt": [task.get_negative_prompt()]
                * get_num_return_sequences(),
                "images": images,
                "width": task.get_width(),
                "height": task.get_height(),
                "num_inference_steps": task.get_steps(),
                "seed": task.get_seed(),
                **task.high_res_kwargs(),
            }
            images, _ = high_res.apply(**kwargs)

        generated_image_urls = upload_images(images, "_imgtoimg", task.get_taskId())
    finally:
        lora_patcher.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    """Inpaint (or outpaint) the task image inside its mask.

    For OUTPAINT tasks, the prompt is a caption produced by ``img2text`` from
    the source image rather than the user's prompt. The mask and results are
    uploaded, and CUDA memory is cleared afterwards.

    Returns:
        dict with ``modified_prompts`` and ``generated_image_urls``.
    """
    is_outpaint = task.get_type() == TaskType.OUTPAINT
    if is_outpaint:
        suffix = "_outpaint"
        caption = img2text.process(task.get_imageUrl())
        prompt = [caption] * get_num_return_sequences()
    else:
        suffix = "_inpaint"
        prompt, _ = get_patched_prompt(task)

    print({"prompts": prompt})

    # Built as a dict (not keyword args) so ip_kwargs may override any key.
    call_args = {
        "prompt": prompt,
        "image_url": task.get_imageUrl(),
        "mask_image_url": task.get_maskImageUrl(),
        "width": task.get_width(),
        "height": task.get_height(),
        "seed": task.get_seed(),
        "negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(),
        "num_inference_steps": task.get_steps(),
        **task.ip_kwargs(),
    }
    results, mask = inpainter.process(**call_args)

    upload_image(mask, "crecoAI/{}_mask.png".format(task.get_taskId()))
    urls = upload_images(results, suffix, task.get_taskId())

    clear_cuda_and_gc()

    return {"modified_prompts": prompt, "generated_image_urls": urls}
@update_db
@slack.auto_send_alert
def replace_bg(task: Task):
    """Replace the background of the task image via ``replace_background``.

    The prompt is expanded by ``prompt_modifier`` when prompt engineering is
    enabled; otherwise it is repeated once per requested sequence.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt = task.get_prompt()
    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(prompt)
    else:
        prompt = [prompt] * get_num_return_sequences()

    lora_patcher = lora_style.get_patcher(replace_background.pipe, task.get_style())
    lora_patcher.patch()
    # Unpatch even if replacement/upload fails so the shared pipe stays clean.
    try:
        images, has_nsfw = replace_background.replace(
            image=task.get_imageUrl(),
            prompt=prompt,
            negative_prompt=[task.get_negative_prompt()] * get_num_return_sequences(),
            seed=task.get_seed(),
            width=task.get_width(),
            height=task.get_height(),
            steps=task.get_steps(),
            apply_high_res=task.get_high_res_fix(),
            conditioning_scale=task.rbg_controlnet_conditioning_scale(),
            model_type=task.get_modelType(),
        )
        generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())
    finally:
        lora_patcher.cleanup()

    clear_cuda_and_gc()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
@update_db
@slack.auto_send_alert
def remove_bg(task: Task):
    """Remove the background of the task image with ``remove_background_v3``.

    The result is uploaded to ``crecoAI/<taskId>_rmbg.png``.

    Returns:
        dict with ``generated_image_url``.
    """
    cutout = remove_background_v3.remove(task.get_imageUrl())
    url = upload_image(cutout, "crecoAI/{}_rmbg.png".format(task.get_taskId()))
    return {"generated_image_url": url}
@update_db
@slack.auto_send_alert
def upscale_image(task: Task):
    """Upscale the task image and upload it to ``crecoAI/<taskId>_upscale.png``.

    ANIME and COMIC model types use ``upscaler.upscale_anime``; every other
    model type uses ``upscaler.upscale``.

    Returns:
        dict with ``generated_image_url``.
    """
    output_key = "crecoAI/{}_upscale.png".format(task.get_taskId())

    model_type = task.get_modelType()
    if model_type == ModelType.ANIME or model_type == ModelType.COMIC:
        print("Using Anime model")
        run_upscale = upscaler.upscale_anime
    else:
        print("Using Real model")
        run_upscale = upscaler.upscale

    upscaled = run_upscale(
        image=task.get_imageUrl(),
        width=task.get_width(),
        height=task.get_height(),
        face_enhance=task.get_face_enhance(),
        resize_dimension=task.get_resize_dimension(),
    )

    # the upscaler's output is wrapped in BytesIO before uploading
    image_url = upload_image(BytesIO(upscaled), output_key)

    clear_cuda_and_gc()

    return {"generated_image_url": image_url}
@update_db
@slack.auto_send_alert
def remove_object(task: Task):
    """Remove the masked object from the task image.

    The first result is uploaded to ``crecoAI/<taskId>_object_remove.png``.

    Returns:
        dict whose ``generated_image_urls`` key holds the single uploaded URL
        (plural key name kept for API compatibility).
    """
    key = "crecoAI/{}_object_remove.png".format(task.get_taskId())

    results = object_removal.process(
        image_url=task.get_imageUrl(),
        mask_image_url=task.get_maskImageUrl(),
        seed=task.get_seed(),
        width=task.get_width(),
        height=task.get_height(),
    )
    url = upload_image(results[0], key)

    clear_cuda()

    return {"generated_image_urls": url}
def rt_draw_seg(task: Task):
    """Realtime segmentation-conditioned drawing.

    ``task.get_imageUrl()`` may be an ``http`` URL (downloaded) or a base64
    image string (decoded).

    Returns:
        dict with the result encoded as base64 under ``image``.
    """
    source = task.get_imageUrl()
    if source.startswith("http"):
        source = download_image(source)
    else:  # consider image as base64
        source = base64_to_image(source)

    drawn = realtime_draw.process_seg(
        image=source,
        prompt=task.get_prompt(),
        negative_prompt=task.get_negative_prompt(),
        seed=task.get_seed(),
    )

    clear_cuda_and_gc()

    return {"image": image_to_base64(drawn)}
def rt_draw_img(task: Task):
    """Realtime image-conditioned drawing from a primary and an auxiliary image.

    Each image reference is either an ``http`` URL (downloaded) or base64
    image data (decoded). Empty or missing references are passed through
    unchanged.

    Returns:
        dict with the result encoded as base64 under ``image``.
    """

    def _decode(ref):
        # URLs are downloaded; anything else is treated as base64 image data.
        if ref.startswith("http"):
            return download_image(ref)
        return base64_to_image(ref)

    image = task.get_imageUrl()
    aux_image = task.get_auxilary_imageUrl()

    if image:
        image = _decode(image)
    if aux_image:
        aux_image = _decode(aux_image)

    drawn = realtime_draw.process_img(
        image=image,  # pyright: ignore
        image2=aux_image,  # pyright: ignore
        prompt=task.get_prompt(),
        negative_prompt=task.get_negative_prompt(),
        seed=task.get_seed(),
    )

    clear_cuda_and_gc()

    return {"image": image_to_base64(drawn)}
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def depth_rig(task: Task):
    """Character-rig generation with the ``depth`` controlnet and a fixed depth map.

    Returns:
        dict with the patched prompt params (``params.__dict__``) plus
        ``generated_image_urls`` and ``has_nsfw``.
    """
    # Note : This task is for only processing a hardcoded character rig model using depth controlnet
    # Hack : This model requires hardcoded depth images for optimal processing, so we pass it by default
    default_depth_url = "https://s3.ap-south-1.amazonaws.com/assets.autodraft.in/character-sheet/rigs/character-rig-depth-map.png"

    params = get_patched_prompt_text2img(task)
    width, height = get_intermediate_dimension(task)

    controlnet.load_model("depth")

    lora_patcher = lora_style.get_patcher(
        [controlnet.pipe2, high_res.pipe], task.get_style()
    )
    lora_patcher.patch()
    # Unpatch even on failure: shared pipes, and predict_fn only cleans the controlnet.
    try:
        kwargs = {
            "params": params,
            "prompt": params.prompt,
            "num_inference_steps": task.get_steps(),
            "imageUrl": default_depth_url,
            "height": height,
            "seed": task.get_seed(),
            "width": width,
            "negative_prompt": task.get_negative_prompt(),
            **task.t2i_kwargs(),
            **lora_patcher.kwargs(),
        }
        (images, has_nsfw), condition_image = controlnet.process(**kwargs)

        if task.get_high_res_fix():
            kwargs = {
                # fall back to empty prompts when the patched params carry none
                "prompt": params.prompt
                if params.prompt
                else [""] * get_num_return_sequences(),
                "negative_prompt": [task.get_negative_prompt()]
                * get_num_return_sequences(),
                "images": images,
                "width": task.get_width(),
                "height": task.get_height(),
                "num_inference_steps": task.get_steps(),
                "seed": task.get_seed(),
                **task.high_res_kwargs(),
            }
            images, _ = high_res.apply(**kwargs)

        upload_image(
            condition_image, "crecoAI/{}_condition.png".format(task.get_taskId())
        )
        generated_image_urls = upload_images(images, "", task.get_taskId())
    finally:
        lora_patcher.cleanup()

    return {
        **params.__dict__,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
def custom_action(task: Task):
    """Run the custom script whose ``__name__`` matches ``task``'s action name.

    Every script module in ``external.scripts.__scripts__`` is instantiated with
    the shared controlnet and LoRA-style objects. An already-cached instance
    with the same ``__name__`` is preferred over the new one; otherwise the new
    instance is cached in ``custom_scripts``. The torch seed is set from the
    task first.

    Returns:
        The matching script's result for ``(task, action_data)``, or None if
        no script name matches.
    """
    from external.scripts import __scripts__

    global custom_scripts

    script_kwargs = {"CONTROLNET": controlnet, "LORASTYLE": lora_style}

    torch.manual_seed(task.get_seed())

    # NOTE(review): the name check is placed inside the loop (original
    # indentation was lost); it is the only reading where each script is
    # considered — confirm.
    for module in __scripts__:
        candidate = module.Script(**script_kwargs)
        cached = _.find(
            custom_scripts, lambda known: known.__name__ == candidate.__name__
        )
        if cached:
            candidate = cached
        else:
            custom_scripts.append(candidate)

        action = task.get_action_data()
        if action["name"] == candidate.__name__:
            return candidate(task, action)
def load_model_by_task(task_type: TaskType, model_id=-1):
    """Make sure the pipelines required by ``task_type`` are loaded.

    Steps, in order:
      1. Inpaint/outpaint tasks call ``clear_networks()``; every other task
         unloads the inpainter.
      2. If the base text2img pipeline is not loaded yet, load it from the
         model dir and derive img2img, high-res, inpainter and controlnet
         from it.
      3. Load the task-specific model (inpainter, replace-bg, realtime draw,
         object removal, upscaler, or a named controlnet).

    Args:
        task_type: the task about to run.
        model_id: unused by the active code; only referenced by the
            commented-out SDXL tile-upscaler path.
    """
    from internals.pipelines.controlnets import clear_networks

    # pre-cleanup inpaint and controlnet models
    if task_type == TaskType.INPAINT or task_type == TaskType.OUTPAINT:
        clear_networks()
    else:
        inpainter.unload()

    # NOTE(review): original indentation was lost; the derived-pipeline setup is
    # assumed to run only on first load, together with text2img_pipe.load — confirm.
    if not text2img_pipe.is_loaded():
        text2img_pipe.load(get_model_dir())
        img2img_pipe.create(text2img_pipe)
        high_res.load(img2img_pipe)

        inpainter.init(text2img_pipe)
        controlnet.init(text2img_pipe)

    if task_type == TaskType.INPAINT or task_type == TaskType.OUTPAINT:
        inpainter.load()
        safety_checker.apply(inpainter)
    elif task_type == TaskType.REPLACE_BG:
        replace_background.load(
            upscaler=upscaler, base=text2img_pipe, high_res=high_res
        )
    elif task_type == TaskType.RT_DRAW_SEG or task_type == TaskType.RT_DRAW_IMG:
        realtime_draw.load(text2img_pipe)
    elif task_type == TaskType.OBJECT_REMOVAL:
        object_removal.load(get_model_dir())
    elif task_type == TaskType.UPSCALE_IMAGE:
        upscaler.load()
    else:
        # controlnet-backed tasks: load the matching named controlnet model
        if task_type == TaskType.TILE_UPSCALE:
            # if get_is_sdxl():
            #     sdxl_tileupscaler.create(high_res, text2img_pipe, model_id)
            # else:
            controlnet.load_model("tile_upscaler")
        elif task_type == TaskType.CANNY:
            controlnet.load_model("canny")
        elif task_type == TaskType.CANNY_IMG2IMG:
            controlnet.load_model("canny_2x")
        elif task_type == TaskType.SCRIBBLE:
            controlnet.load_model("scribble")
        elif task_type == TaskType.LINEARART:
            controlnet.load_model("linearart")
        elif task_type == TaskType.POSE:
            controlnet.load_model("pose")
def unload_model_by_task(task_type: TaskType):
    """Unload the task-specific model that ``task_type`` used.

    Inpaint/outpaint is a no-op (``inpainter.unload()`` is disabled here).
    Replace-bg and object removal unload their own pipelines. The
    controlnet-backed tasks unload the controlnet. Any other task type does
    nothing.
    """
    if task_type in (TaskType.INPAINT, TaskType.OUTPAINT):
        # inpainter.unload() intentionally not called
        return

    controlnet_tasks = (
        # TILE_UPSCALE: the SDXL tile-upscaler unload path is disabled, so it
        # also falls back to the controlnet.
        TaskType.TILE_UPSCALE,
        TaskType.CANNY,
        TaskType.CANNY_IMG2IMG,
        TaskType.SCRIBBLE,
        TaskType.LINEARART,
        TaskType.POSE,
    )

    if task_type == TaskType.REPLACE_BG:
        replace_background.unload()
    elif task_type == TaskType.OBJECT_REMOVAL:
        object_removal.unload()
    elif task_type in controlnet_tasks:
        controlnet.unload()
def apply_safety_checkers():
    """Attach the shared safety checker to the text2img, img2img and controlnet pipelines."""
    for target in (text2img_pipe, img2img_pipe, controlnet):
        safety_checker.apply(target)
def model_fn(model_dir):
    """Load config, avatars, LoRA styles and the base text2img model from ``model_dir``.

    Returns:
        None.
    """
    print("Logs: model loaded .... starts")

    set_model_config(load_model_from_config(model_dir))
    set_root_dir(__file__)

    for load_local_assets in (avatar.load_local, lora_style.load):
        load_local_assets(model_dir)

    load_model_by_task(TaskType.TEXT_TO_IMAGE)

    print("Logs: model loaded ....")
    return None
def auto_unload_task(func):
    """Decorate a request handler so its task's model is unloaded afterwards.

    After ``func`` returns, if low-GPU-memory mode is enabled
    (``get_low_gpu_mem()``), the request payload is re-parsed into a ``Task``
    and ``unload_model_by_task`` is called for its type. The payload is the
    first positional argument, or the ``data`` keyword argument when called
    by keyword.

    ``functools.wraps`` preserves the wrapped handler's name and docstring.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        if get_low_gpu_mem():
            payload = args[0] if args else kwargs["data"]
            task = Task(payload)
            unload_model_by_task(task.get_type())  # pyright: ignore
        return result

    return wrapper
@auto_unload_task
def predict_fn(data, pipe):
    """Run one inference request described by the ``data`` payload.

    Steps:
      1. Apply task config, load the models the task needs, and attach safety
         checkers.
      2. Realtime-draw tasks return early.
      3. Otherwise apply style args, refresh LoRA styles and avatars, then
         dispatch to the handler for the task type.

    ``pipe`` is not used by this function.

    Returns:
        The handler's result. Returns None for SYSTEM_CMD / PRELOAD_MODEL
        tasks, or after any error (the error is reported to Slack, the
        controlnet is cleaned up, and the source is marked failed).
    """
    task = Task(data)
    print("task is ", data)

    clear_cuda_and_gc()

    try:
        task_type = task.get_type()

        # Set set_environment
        set_configs_from_task(task)

        # Load model based on task
        load_model_by_task(task_type or TaskType.TEXT_TO_IMAGE, task.get_model_id())

        # Apply safety checkers
        apply_safety_checkers()

        # Realtime generation apis (skip style/avatar setup)
        if task_type == TaskType.RT_DRAW_SEG:
            return rt_draw_seg(task)
        if task_type == TaskType.RT_DRAW_IMG:
            return rt_draw_img(task)

        # Apply arguments
        apply_style_args(data)
        # Re-fetch styles
        lora_style.fetch_styles()
        # Fetch avatars
        avatar.fetch_from_network(task.get_model_id())

        if task_type == TaskType.TEXT_TO_IMAGE:
            # Hack : Character Rigging Model Task Redirection
            if task.get_model_id() in (2000336, 2000341):
                return depth_rig(task)
            return text2img(task)

        handlers = {
            TaskType.IMAGE_TO_IMAGE: img2img,
            TaskType.CANNY: canny,
            TaskType.CANNY_IMG2IMG: canny_img2img,
            TaskType.POSE: pose,
            TaskType.TILE_UPSCALE: tile_upscale,
            TaskType.INPAINT: inpaint,
            TaskType.OUTPAINT: inpaint,
            TaskType.SCRIBBLE: scribble,
            TaskType.LINEARART: linearart,
            TaskType.REPLACE_BG: replace_bg,
            TaskType.CUSTOM_ACTION: custom_action,
            TaskType.REMOVE_BG: remove_bg,
            TaskType.UPSCALE_IMAGE: upscale_image,
            TaskType.OBJECT_REMOVAL: remove_object,
        }
        handler = handlers.get(task_type)
        if handler is not None:
            return handler(task)

        if task_type == TaskType.SYSTEM_CMD:
            # SECURITY: this runs an arbitrary shell command taken verbatim from
            # the request payload (the task prompt). Anyone able to submit
            # tasks can execute code on this host — restrict or remove it.
            os.system(task.get_prompt())
        elif task_type == TaskType.PRELOAD_MODEL:
            try:
                task_type = TaskType(task.get_prompt())
            except (ValueError, TypeError):
                # Not a valid TaskType value: preload only the base pipelines.
                task_type = TaskType.SYSTEM_CMD
            load_model_by_task(task_type)
        else:
            raise Exception("Invalid task type")
    except Exception as e:
        slack.error_alert(task, e)
        controlnet.cleanup()
        traceback.print_exc()
        update_db_source_failed(task.get_sourceId(), task.get_userId())
        return None