Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
import spaces | |
from diffusers import DiffusionPipeline | |
from pathlib import Path | |
import gc | |
import subprocess | |
import os
import sys

# Best-effort install of flash-attn at startup. Return codes are deliberately
# not checked, matching the original behavior.
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE asks flash-attn's setup not to compile
# its CUDA extension at install time (presumably so a prebuilt wheel is used;
# confirm against flash-attn's setup.py).
# The variable is ADDED to the inherited environment. Passing a bare env={...}
# would replace os.environ entirely and drop PATH, HOME, HF_* tokens and the
# active virtualenv for the child process.
# `sys.executable -m pip` targets the interpreter running this app, and the
# argument-list form avoids spawning a shell.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)
subprocess.run([sys.executable, "-m", "pip", "cache", "purge"])
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# This app only does inference: turn autograd off (grad mode is a
# per-thread setting in PyTorch; this affects the importing thread).
torch.set_grad_enabled(False)

# Hub repo ids offered as base models; the first entry is loaded at startup.
models = [
    "camenduru/FLUX.1-dev-diffusers",
    "black-forest-labs/FLUX.1-schnell",
    "sayakpaul/FLUX.1-merged",
    "John6666/blue-pencil-flux1-v001-fp8-flux",
    "John6666/copycat-flux-test-fp8-v11-fp8-flux",
    "John6666/nepotism-fuxdevschnell-v3aio-fp8-flux",
    "John6666/niji-style-flux-devfp8-fp8-flux",
    "John6666/fluxunchained-artfulnsfw-fut516xfp8e4m3fnv11-fp8-flux",
    "John6666/fastflux-unchained-t5f16-fp8-flux",
    "John6666/the-araminta-flux1a1-fp8-flux",
    "John6666/acorn-is-spinning-flux-v11-fp8-flux",
    "John6666/fluxescore-dev-v10fp16-fp8-flux",
]

# NOTE(review): presumably the number of LoRA slots shown in the UI; confirm
# against the interface code.
num_loras = 3
def is_repo_name(s):
    """Return True if ``s`` looks like a Hub repo id of the form ``owner/name``.

    Both halves must be non-empty and may not contain "/", ",", whitespace,
    or quote characters. Non-string input (e.g. None from an empty UI field)
    returns False instead of raising TypeError.
    """
    import re
    if not isinstance(s, str):
        return False
    # fullmatch already anchors both ends, so no ^/$ is needed.
    return re.fullmatch(r'[^/,\s\"\']+/[^/,\s\"\']+', s) is not None
def is_repo_exists(repo_id):
    """Report whether ``repo_id`` exists on the Hugging Face Hub.

    Fails open: if the lookup itself errors, the error is printed and True
    is returned, so a flaky connection does not block the caller.
    """
    from huggingface_hub import HfApi
    client = HfApi()
    try:
        return bool(client.repo_exists(repo_id=repo_id))
    except Exception as err:
        print(f"Error: Failed to connect {repo_id}. ")
        print(err)
        return True  # for safe
def clear_cache():
    """Free unreachable Python objects, then release PyTorch's cached CUDA memory.

    gc.collect() runs first so tensors kept alive only by reference cycles
    are actually freed. torch.cuda.empty_cache() can then return their
    now-unused cached blocks; in the reverse order those blocks would still
    be occupied when the cache is emptied. empty_cache() is a no-op when
    CUDA was never initialized, so this is safe on CPU-only machines.
    """
    gc.collect()
    torch.cuda.empty_cache()
def get_repo_safetensors(repo_id: str):
    """Build a Gradio update listing the ``.safetensors`` files of a Hub repo.

    Returns gr.update(value=<first file>, choices=<all files>) when the repo
    has any. Otherwise (invalid id, missing repo, lookup error, or no
    safetensors files) returns gr.update(value="", choices=[]), so no stale
    selection is left behind.
    """
    from huggingface_hub import HfApi
    try:
        if not is_repo_name(repo_id) or not is_repo_exists(repo_id):
            return gr.update(value="", choices=[])
        files = HfApi().list_repo_files(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to get {repo_id}'s info.")
        print(e)
        # Clear the value too, consistent with the other "nothing found" paths.
        return gr.update(value="", choices=[])
    files = [f for f in files if f.endswith(".safetensors")]
    if not files:
        return gr.update(value="", choices=[])
    return gr.update(value=files[0], choices=files)
# Initialize the base model: load the first entry of `models` at import time.
base_model = models[0]
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
# Repo id that `pipe` was built from; change_base_model() compares against it
# to skip reloading the same checkpoint.
last_model = base_model
def change_base_model(repo_id: str, progress=gr.Progress(track_tqdm=True)):
    """Replace the global diffusion pipeline with the one from ``repo_id``.

    Returns None without doing anything when repo_id is already loaded, is
    not an ``owner/name`` id, or is reported missing by the Hub. Otherwise
    it loads the new pipeline (bfloat16) and returns gr.update(visible=True).
    That includes the case where loading fails: the error is printed and the
    previous pipeline stays in place.
    """
    global pipe
    global last_model
    try:
        if repo_id == last_model or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return
        progress(0, desc=f"Loading model: {repo_id}")
        clear_cache()
        pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
        last_model = repo_id
        # The old pipeline was still referenced by `pipe` during the clear_cache()
        # above, so nothing was released then. It is unreachable only now, after
        # the rebind, so collect again to actually free its memory.
        clear_cache()
        progress(1, desc=f"Model loaded: {repo_id}")
    except Exception as e:
        print(e)
    return gr.update(visible=True)
def compose_lora_json(lorajson: list[dict], i: int, name: str, scale: float, filename: str, trigger: str):
    """Store one LoRA slot's UI values into ``lorajson[i]`` and return the list.

    The list is modified in place and also returned. A name of "None" or None
    is stored as "" (no LoRA selected). None for filename or trigger is
    stored as "" rather than the literal string "None". Otherwise "None"
    would be used as a weight file name, and get_trigger_word() would emit
    it as a ", None" trigger.
    """
    def _text(value):
        # str(None) == "None"; map a missing value to an empty string instead.
        return "" if value is None else str(value)

    slot = lorajson[i]
    slot["name"] = "" if name is None or name == "None" else str(name)
    slot["scale"] = float(scale)
    slot["filename"] = _text(filename)
    slot["trigger"] = _text(trigger)
    return lorajson
def is_valid_lora(lorajson: list[dict]):
    """Return True when at least one entry names a LoRA (non-empty and not "None")."""
    def names_a_lora(entry):
        label = entry.get("name")
        return bool(label) and label != "None"

    return any(names_a_lora(entry) for entry in lorajson)
def get_trigger_word(lorajson: list[dict]):
    """Collect trigger words of the selected LoRAs as a ", "-prefixed string.

    Each entry with a real name (non-empty, not "None") and a non-empty
    trigger contributes ", <trigger>", in list order. The result is "" when
    nothing qualifies. Entries without a "trigger" key are skipped instead of
    raising KeyError.
    """
    words = [
        d["trigger"]
        for d in lorajson
        if d.get("name") and d["name"] != "None" and d.get("trigger")
    ]
    return "".join(", " + word for word in words)
def _unique_adapter_name(source, taken):
    """Derive an adapter name from a repo id / file path that is safe for PEFT and unique in ``taken``."""
    # PEFT stores adapters in torch.nn.ModuleDict, whose keys may not contain ".".
    # Path.stem only strips the last suffix, e.g. "my.lora.safetensors" -> "my.lora".
    base = Path(source).stem.replace(".", "_") or "lora"
    name, n = base, 1
    while name in taken:
        # Two LoRAs whose basenames coincide would otherwise collide on load.
        name = f"{base}_{n}"
        n += 1
    return name


# https://huggingface.co/docs/diffusers/v0.23.1/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora
# https://github.com/huggingface/diffusers/issues/4919
def fuse_loras(pipe, lorajson: list[dict]):
    """Load the LoRAs selected in ``lorajson`` into ``pipe`` and fuse them.

    Each usable entry is a dict whose "name" is either a Hub repo id (loaded
    with its "filename" as weight_name) or an existing local file path. Its
    "scale" becomes that adapter's weight. Entries that are empty, non-dict,
    unnamed, or named "None" are skipped, as are local paths that do not
    exist. If nothing loads, ``pipe`` is left untouched. Returns None.
    """
    if not lorajson or not isinstance(lorajson, list): return
    a_list = []
    w_list = []
    for d in lorajson:
        if not d or not isinstance(d, dict): continue
        k = d.get("name")
        if not k or k == "None": continue
        if is_repo_name(k) and is_repo_exists(k):
            weight_name = d["filename"]
        elif not Path(k).exists():
            print(f"LoRA not found: {k}")
            continue
        else:
            weight_name = Path(k).name
        a_name = _unique_adapter_name(k, a_list)
        pipe.load_lora_weights(k, weight_name=weight_name, adapter_name=a_name)
        a_list.append(a_name)
        w_list.append(d["scale"])
    if not a_list: return
    pipe.set_adapters(a_list, adapter_weights=w_list)
    pipe.fuse_lora(adapter_names=a_list, lora_scale=1.0)
    # NOTE(review): adapters stay loaded after fusing (the unload below is
    # disabled). A later call that reuses an adapter name may collide unless
    # the caller unloads first; confirm how callers reset `pipe`.
    #pipe.unload_lora_weights()
def description_ui():
    """Render the Space's attribution/credits blurb as a Markdown component."""
    credits = """
- Mod of [multimodalart/flux-lora-the-explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer),
[gokaygokay/FLUX-Prompt-Generator](https://huggingface.co/spaces/gokaygokay/FLUX-Prompt-Generator).
    """
    gr.Markdown(credits)
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM | |
def load_prompt_enhancer():
    """Build a text2text-generation pipeline for gokaygokay/Flux-Prompt-Enhance.

    The model is put in eval mode and moved to the module-level ``device``.
    If any step fails, the error is printed and None is returned.
    """
    checkpoint = "gokaygokay/Flux-Prompt-Enhance"
    try:
        tok = AutoTokenizer.from_pretrained(checkpoint)
        seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).eval().to(device=device)
        return pipeline('text2text-generation', model=seq2seq, tokenizer=tok, repetition_penalty=1.5, device=device)
    except Exception as err:
        print(err)
        return None
enhancer_flux = load_prompt_enhancer()


def enhance_prompt(input_prompt):
    """Rewrite ``input_prompt`` into a more detailed prompt with the enhancer model.

    The model is called with "enhance prompt: <input>" and max_length=256,
    and its first generated text is returned. If the enhancer failed to load
    (load_prompt_enhancer() returned None), the prompt is returned unchanged
    instead of raising TypeError from calling None.
    """
    if enhancer_flux is None:
        return input_prompt
    result = enhancer_flux("enhance prompt: " + input_prompt, max_length = 256)
    return result[0]['generated_text']
# Mark these functions with a `zerogpu` attribute.
# NOTE(review): nothing visible in this file reads the attribute; presumably
# the Space's ZeroGPU / `spaces` integration does. Confirm before relying on it.
for _gpu_fn in (load_prompt_enhancer, change_base_model, fuse_loras):
    _gpu_fn.zerogpu = True
del _gpu_fn