import os import random import numpy as np import gradio as gr import base64 from io import BytesIO import PIL.Image from typing import Tuple from novita_client import NovitaClient, V3TaskResponseStatus from time import time from style_template import styles # global variable MAX_SEED = np.iinfo(np.int32).max STYLE_NAMES = list(styles.keys()) DEFAULT_STYLE_NAME = 'Watercolor' DEFAULT_MODEL_NAME = 'sdxlUnstableDiffusers_v8HEAVENSWRATH_133813' enable_lcm_arg = False # Path to InstantID models face_adapter = f'./checkpoints/ip-adapter.bin' controlnet_path = f'./checkpoints/ControlNetModel' # controlnet-pose/canny/depth controlnet_pose_model = 'thibaud/controlnet-openpose-sdxl-1.0' controlnet_canny_model = 'diffusers/controlnet-canny-sdxl-1.0' controlnet_depth_model = 'diffusers/controlnet-depth-sdxl-1.0-small' SDXL_MODELS = [ "albedobaseXL_v04_130099", "altxl_v60_146691", "animagineXLV3_v30_231047", "animeArtDiffusionXL_alpha2_91872", "animeArtDiffusionXL_alpha3_93120", "animeIllustDiffusion_v04_117809", "breakdomainxl_V05g_124265", "brixlAMustInYour_v40Dagobah_145992", "cinemaxAlphaSDXLCinema_alpha1_107473", "cineroXLPhotomatic_v12aPHENO_137703", "clearhungAnimeXL_v10_117716", "copaxTimelessxlSDXL1_colorfulV2_100729", "counterfeitxl_v10_108721", "counterfeitxl__98184", "crystalClearXL_ccxl_97637", "dreamshaperXL09Alpha_alpha2Xl10_91562", "dynavisionXLAllInOneStylized_alpha036FP16Bakedvae_99980", "dynavisionXLAllInOneStylized_beta0411Bakedvae_109970", "dynavisionXLAllInOneStylized_release0534bakedvae_129001", "fenrisxl_145_134980", "foddaxlPhotorealism_v45_122788", "formulaxl_v10_104889", "juggernautXL_v8Rundiffusion_227002", "juggernautXL_version2_113240", "juggernautXL_version5_126522", "kohakuXL_alpha7_111843", "LahMysteriousSDXL_v40_122478", "leosamsHelloworldSDXLModel_helloworldSDXL10_112178", "leosamsHelloworldSDXL_helloworldSDXL50_268813", "mbbxlUltimate_v10RC_94686", "moefusionSDXL_v10_114018", "nightvisionXLPhotorealisticPortrait_beta0681Bakedvae_108833", "nightvisionXLPhotorealisticPortrait_beta0702Bakedvae_113098", "nightvisionXLPhotorealisticPortrait_release0770Bakedvae_154525", "novaPrimeXL_v10_107899", "pixelwave_v10_117722", "protovisionXLHighFidelity3D_beta0520Bakedvae_106612", "protovisionXLHighFidelity3D_release0620Bakedvae_131308", "protovisionXLHighFidelity3D_release0630Bakedvae_154359", "protovisionXLHighFidelity3D_releaseV660Bakedvae_207131", "realismEngineSDXL_v05b_131513", "realismEngineSDXL_v10_136287", "realisticStockPhoto_v10_115618", "RealitiesEdgeXL_4_122673", "realvisxlV20_v20Bakedvae_129156", "riotDiffusionXL_v20_139293", "roxl_v10_109354", "sdxlNijiSpecial_sdxlNijiSE_115638", "sdxlNijiV3_sdxlNijiV3_104571", "sdxlNijiV51_sdxlNijiV51_112807", "sdxlUnstableDiffusers_v8HEAVENSWRATH_133813", "sdxlYamersAnimeUltra_yamersAnimeV3_121537", "sd_xl_base_0.9", "sd_xl_base_1.0", "shikianimexl_v10_93788", "theTalosProject_v10_117893", "thinkdiffusionxl_v10_145931", "voidnoisecorexl_r1486_150780", "wlopArienwlopstylexl_v10_101973", "wlopSTYLEXL_v2_126171", "xl13AsmodeusSFWNSFW_v22BakedVAE_111954", "xxmix9realisticsdxl_v10_123235", "zavychromaxl_b2_103298", ] LORA_MODELS = [ "DI_belle_delphine_sdxl_v1_93586", #"NsfwPovAllInOneLoraSdxl-000009MINI_120545", "NsfwPovAllInOneLoraSdxl-000009_120561", "acidzlime-sdxl_154149", "add-detail-xl_99264", "bwporcelaincd_xl-000007_124344", "concept_pov_dt_xl2-000020_119643", "epoxy_skull-sdxl_153213", "landscape-painting-sdxl_v2_111037", "polyhedron_all_sdxl-000004_110557", "ral-beer-sdxl_235173", "ral-wtchz-sdxl_233487", "sdxl_cute_social_comic-000002_107980", "sdxl_glass_136034", "sdxl_lightning_8step_lora_290441", "sdxl_offset_example_v10_113006", "sdxl_wrong_lora", "xl_more_art-full_v1_113467", "xl_yoshiaki_kawajiri_v1r64_126468", ] CONTROLNET_DICT = dict( pose={ 'model_name': 'controlnet-openpose-sdxl-1.0', 'strength': 1, 'preprocessor': 'openpose', }, depth={ 'model_name': 'controlnet-depth-sdxl-1.0', 'strength': 1, 'preprocessor': 'depth', }, canny={ 'model_name': 'controlnet-canny-sdxl-1.0', 'strength': 1, 'preprocessor': 'canny', }, lineart={ 'model_name': 'controlnet-softedge-dexined-sdxl-1.0', 'strength': 1, 'preprocessor': 'lineart', }, ) last_check = 0 def get_novita_client (novita_key): client = NovitaClient(novita_key, os.getenv('NOVITA_API_URI', None)) return client get_local_storage = ''' function () { globalThis.setStorage = (key, value)=>{ localStorage.setItem(key, JSON.stringify(value)) } globalThis.getStorage = (key, value)=>{ return JSON.parse(localStorage.getItem(key)) } const novita_key = getStorage("novita_key") return [novita_key]; } ''' def toggle_lcm_ui (value): if value: return ( gr.update(minimum=0, maximum=100, step=1, value=5), gr.update(minimum=0.1, maximum=20.0, step=0.1, value=1.5), ) else: return ( gr.update(minimum=5, maximum=100, step=1, value=30), gr.update(minimum=0.1, maximum=20.0, step=0.1, value=5), ) def randomize_seed_fn (seed: int, randomize_seed: bool) -> int: if randomize_seed: seed = random.randint(0, MAX_SEED) return seed def remove_tips (): return gr.update(visible=False) def apply_style (style_name: str, positive: str, negative: str = "") -> Tuple[str, str]: p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME]) return p.replace("{prompt}", positive), n + " " + negative def get_example (): case = [ [ './examples/yann-lecun_resize.jpg', None, 'a man', 'Spring Festival', '(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green', ], [ './examples/musk_resize.jpeg', './examples/poses/pose2.jpg', 'a man flying in the sky in Mars', 'Mars', '(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green', ], [ './examples/sam_resize.png', './examples/poses/pose4.jpg', 'a man doing a silly pose wearing a suite', 'Jungle', '(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, gree', ], [ './examples/schmidhuber_resize.png', './examples/poses/pose3.jpg', 'a man sit on a chair', 'Neon', '(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green', ], [ './examples/kaifu_resize.png', './examples/poses/pose.jpg', 'a man', 'Vibrant Color', '(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green', ], ] return case def load_example (face_file, pose_file, prompt, style, negative_prompt): name = os.path.basename(face_file).split('_')[0] image = PIL.Image.open(open(f'./examples/generated/{name}.jpg', 'rb')) return image, gr.update(visible=True) upload_depot = {} def upload_assets_with_cache (client, paths): global upload_depot pending_paths = [path for path in paths if not path in upload_depot] if pending_paths: print('uploading images:', pending_paths) for key, value in zip(pending_paths, client.upload_assets(pending_paths)): upload_depot[key] = value return [upload_depot[path] for path in paths] def generate_image ( novita_key1, model_name, lora_selection, face_image_path, pose_image_path, prompt, negative_prompt, style_name, num_steps, identitynet_strength_ratio, adapter_strength_ratio, controlnet_strengths, controlnet_selection, guidance_scale, seed, scheduler, #enable_LCM, #enhance_face_region, progress=gr.Progress(track_tqdm=True), ): if face_image_path is None: raise gr.Error(f'Cannot find any input face image! Please refer to step 1️⃣') #print('novita_key:', novita_key1) #print('face_image_path:', face_image_path) if not novita_key1: raise gr.Error(f'Please input your Novita Key!') try: client = get_novita_client(novita_key1) prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt) prompt = prompt[:1024] #print('prompt:', prompt) #print('negative_prompt:', negative_prompt) #print('seed:', seed) #print('identitynet_strength_ratio:', identitynet_strength_ratio) #print('adapter_strength_ratio:', adapter_strength_ratio) #print('scheduler:', scheduler) #print('guidance_scale:', guidance_scale) #print('num_steps:', num_steps) ref_image_path = pose_image_path if pose_image_path else face_image_path ref_image = PIL.Image.open(ref_image_path) width, height = ref_image.size large_edge = max(width, height) if large_edge < 1024: scaling = 1024 / large_edge width = int(width * scaling) height = int(height * scaling) ( CONTROLNET_DICT['pose']['strength'], CONTROLNET_DICT['canny']['strength'], CONTROLNET_DICT['depth']['strength'], CONTROLNET_DICT['lineart']['strength'], ) = controlnet_strengths face_image_uploaded, ref_image_uploaded = upload_assets_with_cache(client, [face_image_path, ref_image_path]) res = client._post('/v3/async/instant-id', { 'extra': { 'response_image_type': 'jpeg', }, 'model_name': f'{model_name}.safetensors', 'face_image_assets_ids': [face_image_uploaded], 'ref_image_assets_ids': [ref_image_uploaded], 'prompt': prompt, 'negative_prompt': negative_prompt, 'controlnet': { 'units': [CONTROLNET_DICT[name] for name in controlnet_selection if name in CONTROLNET_DICT], }, 'loras': [dict( model_name=f'{name}.safetensors', scale=1, ) for name in lora_selection], 'image_num': 1, 'steps': num_steps, 'seed': seed, 'guidance_scale': guidance_scale, 'sampler_name': scheduler, 'id_strength': identitynet_strength_ratio, 'adapter_strength': adapter_strength_ratio, 'width': width, 'height': height, }) print('task_id:', res['task_id']) def progress (x): global last_check t = time() if t > last_check + 5: last_check = t print('progress:', t, x.task.status) final_res = client.wait_for_task_v3(res['task_id'], callback=progress) if final_res is None or final_res.task.status == V3TaskResponseStatus.TASK_STATUS_FAILED: raise RuntimeError(f'Novita task failed: {final_res and final_res.task.status}') print('status:', final_res.task.status) print('returned images:', final_res.images) final_res.download_images() except Exception as e: raise gr.Error(f'Error: {e}') #print('final_res:', final_res) #print('final_res.images_encoded:', final_res.images_encoded) image = PIL.Image.open(BytesIO(base64.b64decode(final_res.images_encoded[0]))) return image, gr.update(visible=True) # Description title = r'''