"""Flux Prompt Enhance: a Gradio UI around the gokaygokay/Flux-Prompt-Enhance model.

The model expands a short prompt into a detailed one. Optional add-ons are
detected at import time and the UI adapts to them:

* ``google_translate`` -- Russian<->English translation of the input/output and
  a Russian UI;
* ``config_ui`` -- persisted settings (UI language, device);
* ``prompt.portrait_prompt`` -- a "random prompt" button.
"""

import os
import platform
import time

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


def is_running_in_huggingface_spaces():
    """Return True when the SPACE_ID environment variable is set (Hugging Face Spaces)."""
    return "SPACE_ID" in os.environ


def _is_apple_silicon():
    """Return True on macOS running on an ARM (Apple Silicon) machine."""
    return platform.system() == "Darwin" and platform.machine().startswith("arm")


def _pick_default_device():
    """Return the best available device name when no saved config is present."""
    if _is_apple_silicon():
        print("run on mac with Apple Silicon")
        # MPS = Metal Performance Shaders. Fall back to CPU instead of leaving
        # the device undefined when MPS is not available.
        return "mps" if torch.backends.mps.is_available() else "cpu"
    # TODO: parse an env var to choose a specific CUDA device.
    # "cuda" (rather than the bare index 0) keeps str(device) == "cuda", which
    # set_initial / save_config rely on.
    return "cuda" if torch.cuda.is_available() else "cpu"


# Optional Russian <-> English translator.
try:
    from google_translate import TranslatorWithCache
except ImportError:
    is_google_translate_installed = False
else:
    is_google_translate_installed = True
    translator = TranslatorWithCache()

# Optional persisted configuration (UI language and device).
try:
    from config_ui import Config
except ImportError:
    is_config_ui_installed = False
    device = _pick_default_device()
    lang = "EN"
else:
    is_config_ui_installed = True
    config = Config()
    # save_config may persist "cuda", "mps" or "cpu"; honour each one only
    # when that backend is actually available on this machine.
    if config.cuda == "cuda" and torch.cuda.is_available():
        device = "cuda"
    elif config.cuda == "mps" and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    lang = config.lang

# Optional random portrait-prompt generator. Best effort: any failure while
# importing it just hides the button (but no longer swallows SystemExit /
# KeyboardInterrupt like the former bare ``except:``).
try:
    from prompt.portrait_prompt import generate_random_portrait_prompt
    is_rnd_gen_installed = True
except Exception:
    is_rnd_gen_installed = False

model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)
max_target_length = 256
prefix = "enhance prompt"


def _parse_seed(seed):
    """Convert a seed coming from the UI to int; return -1 when it is not a valid integer.

    The seed widget is a gr.Textbox, so the value usually arrives as a str.
    """
    try:
        return int(seed)
    except (TypeError, ValueError):
        return -1


def enhance_prompt(prompt, system_prompt, temperature=0.5, repetition_penalty=1.2, seed=-1, is_rnd_seed=True):
    """Expand *prompt* with the seq2seq model.

    Args:
        prompt: user prompt (translated from Russian to English first when the
            translator add-on is installed).
        system_prompt: task prefix prepended as "<system_prompt>: <prompt>".
        temperature: sampling temperature passed to ``model.generate``.
        repetition_penalty: repetition penalty passed to ``model.generate``.
        seed: RNG seed (int or numeric str); -1 or a non-numeric value means random.
        is_rnd_seed: when True, always draw a fresh random seed.

    Returns:
        Tuple ``(seed, generated_text_en, result_output_ru, time_str)``;
        ``result_output_ru`` equals the English text when no translator is installed.
    """
    start_time = time.perf_counter()

    seed = _parse_seed(seed)
    if is_rnd_seed or seed == -1:
        seed = torch.randint(0, 2**32 - 1, (1,)).item()
    torch.manual_seed(seed)

    if is_google_translate_installed:
        # Russian -> English translation of the input prompt.
        prompt = translator.translate_ru2eng(prompt)
    input_text = f"{system_prompt}: {prompt}"
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

    outputs = model.generate(
        input_ids,
        max_length=max_target_length,
        num_return_sequences=1,
        do_sample=True,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )
    generated_text_en = tokenizer.decode(outputs[0], skip_special_tokens=True)

    if is_google_translate_installed:
        result_output_ru = translator.translate_eng2ru(generated_text_en)
    else:
        result_output_ru = generated_text_en

    execution_time = time.perf_counter() - start_time
    if lang == "EN":
        time_str = f"execution time: {execution_time:.2f} s."
    else:
        time_str = f"время выполнения: {execution_time:.2f} с."
    return seed, generated_text_en, result_output_ru, time_str


def random_prompt():
    """Return a randomly generated portrait prompt."""
    return generate_random_portrait_prompt()


def copy_to_clipboard(text):
    """Show a toast; the actual copy is done client-side by the button's ``js`` handler."""
    gr.Info("скопировано в буффер обмена" if (lang == "RU") else "copy to clipboard", duration=1)
    return None


LABELS_EN = {
    "prompt_input": "Input initial prompt:",
    "seed_output": "Seed:",
    "result_output": "Improved prompt",
    "result_output_ru": "Improved prompt (in Russian)",
    "generate_button": "Improve prompt",
    "copy_button": "Copy to clipboard",
    "save_button": "Save config",
    "advanced": "Advanced settings:",
    "system_prompt": "System prompt:",
    "temperature": "Temperature",
    "repetition_penalty": "Repetition penalty",
    "is_rnd_seed": "Random Seed",
}
LABELS = LABELS_EN

if is_google_translate_installed:
    LABELS_RU = {
        "prompt_input": "Введите начальный промпт:",
        "seed_output": "Seed для генерации:",
        "result_output": "Улучшенный промпт (на английском):",
        "result_output_ru": "Улучшенный промпт (на русском):",
        "generate_button": "Улучшить промпт",
        "copy_button": "Скопировать в буффер обмена",
        "save_button": "Сохранить настройки",
        "advanced": "Расширенные настройки:",
        "system_prompt": "Системный промпт:",
        "temperature": "Температура",
        "repetition_penalty": "Штраф за повторение",
        "is_rnd_seed": "Случайный Seed",
    }
    LABELS = LABELS_EN if lang == "EN" else LABELS_RU

    def process_lang(selected_lang):
        """Switch the UI language; return a status message plus label updates.

        The 12 updates are ordered to match the ``outputs`` list wired to
        ``radio_lang.change`` (after ``log_text``). Anything other than "RU"
        falls back to English instead of leaving the locals unbound.
        """
        global lang
        lang = selected_lang
        if selected_lang == "RU":
            labels = LABELS_RU
            message = "Вы выбрали русский"
            is_visible = True
        else:
            labels = LABELS_EN
            message = "You selected English"
            is_visible = False

        updates = [
            gr.update(value=labels["generate_button"]),
            gr.update(value=labels["copy_button"]),
            gr.update(value=labels["save_button"]),
            gr.update(label=labels["prompt_input"]),
            gr.update(label=labels["seed_output"]),
            gr.update(label=labels["is_rnd_seed"]),
            gr.update(label=labels["result_output"]),
            gr.update(visible=is_visible, label=labels["result_output_ru"]),
            gr.update(label=labels["advanced"]),
            gr.update(label=labels["system_prompt"]),
            gr.update(label=labels["temperature"]),
            gr.update(label=labels["repetition_penalty"]),
        ]
        return message, *updates

if is_config_ui_installed:
    def save_config():
        """Persist the current language, device and Advanced-panel state via ``config``."""
        config.set_lang(lang)
        config.set_cuda(str(device))
        # NOTE(review): server-side, a component's ``.open`` is the value it was
        # constructed with, not the live browser state -- confirm whether this
        # should instead track the Accordion's expand/collapse events.
        is_open_advanced = AccordionAdvanced.open
        # The former ``config.set_OpenAdvanced=(...)`` overwrote the setter with
        # a bool instead of calling it. Call it only if Config provides it.
        setter = getattr(config, "set_OpenAdvanced", None)
        if callable(setter):
            setter(is_open_advanced)
        # Write the changes to the config file.
        config.save()
        return "save config to file" if lang == 'EN' else "Конфигурация сохранена в файл"


def process_gpu(selected_gpu):
    """Move the model to *selected_gpu* ("cuda", "mps" or "cpu") and return a status message."""
    global model, device
    device = torch.device(selected_gpu)
    model = model.to(device)
    if lang == "RU":
        return f"Модель переключена на устройство: {selected_gpu}"
    return f"Model switched to device: {selected_gpu}"


def set_initial():
    """On page load, push the current language/device into the radios and return a status line.

    The device value sent to ``radio_gpu`` must match the device the model is
    really on: ``radio_gpu.change`` calls ``process_gpu``, so a wrong value here
    would move the model. This function no longer rebinds the global ``device``.
    """
    dev = torch.device(device).type  # accepts "cuda", "mps", "cpu", an index or a torch.device
    if dev == "cuda":
        device_name = f"GPU: {torch.cuda.get_device_name(torch.cuda.current_device())}"
    elif dev == "mps":
        device_name = "running on Apple Silicon (MPS)"
    else:
        dev = "cpu"
        device_name = "running on CPU"
    return gr.update(value=lang), gr.update(value=dev), f'{device_name}, set to "{lang}" language'


# Gradio interface
with gr.Blocks(
    title="Flux Prompt Enhance",
    theme=gr.themes.Default(primary_hue=gr.themes.colors.sky, secondary_hue=gr.themes.colors.indigo),
    analytics_enabled=False,
    css="footer{display:none !important}",
) as demo:
    gr.Image(
        label="header AiCave",
        value="./static/ai_cave_title.jpg",
        height="100%",
        show_download_button=False,
        show_label=False,
        show_share_button=False,
        interactive=False,
        show_fullscreen_button=False,
    )
    with gr.Row(variant="default"):
        running_in = "for spaces" if is_running_in_huggingface_spaces() else "portable"
        # NOTE(review): the header markup appears to have been lost from this
        # string; only its visible text content is preserved here.
        gr.HTML(f"\n\nFlux Prompt Enhance {running_in} by CaveMan\n\n")
    with gr.Row(variant="default"):
        # UI language selector (only meaningful when the translator is installed).
        radio_lang = gr.Radio(
            choices=["RU", "EN"], show_label=False, container=False, type="value",
            visible=is_google_translate_installed,
        )
        if _is_apple_silicon() and torch.backends.mps.is_available():
            print("radio_gpu mac")
            radio_gpu = gr.Radio(
                choices=["mps", "cpu"], show_label=False, container=False, type="value", visible=True,
            )
        else:
            radio_gpu = gr.Radio(
                choices=["cuda", "cpu"], show_label=False, container=False, type="value",
                visible=torch.cuda.is_available(),
            )
        save_button = gr.Button(LABELS["save_button"], visible=is_config_ui_installed)
    with gr.Row(variant="default"):
        prompt_input = gr.Textbox(label=LABELS["prompt_input"])
        if is_rnd_gen_installed:
            button_random = gr.Button("", icon="./static/random.png", scale=0, min_width=200)
            button_random.click(fn=random_prompt, outputs=prompt_input)
    with gr.Accordion(label=LABELS["advanced"], open=False) as AccordionAdvanced:
        with gr.Row(variant="default"):
            system_prompt = gr.Textbox(label=LABELS["system_prompt"], interactive=False, value=prefix)
            seed_output = gr.Textbox(label=LABELS["seed_output"], interactive=True, value=502119)
            is_rnd_seed = gr.Checkbox(value=True, label=LABELS["is_rnd_seed"], interactive=True)
        with gr.Row(variant="default"):
            temperature = gr.Slider(
                label=LABELS["temperature"], interactive=True, value=0.7, minimum=0.1, maximum=1, step=0.1,
            )
            repetition_penalty = gr.Slider(
                label=LABELS["repetition_penalty"], interactive=True, value=1.2, minimum=0.1, maximum=2, step=0.1,
            )
    result_output = gr.Textbox(label=LABELS["result_output"], interactive=False)
    # The Russian result only carries a translation when the translator is installed.
    result_output_ru = gr.Textbox(
        label=LABELS["result_output_ru"], interactive=False,
        visible=is_google_translate_installed and lang != "EN",
    )
    # Generate button
    with gr.Row(variant="default"):
        generate_button = gr.Button(LABELS["generate_button"], variant="primary", size="lg")
        # Copy-to-clipboard button: the copy itself runs in the browser via ``js``.
        copy_button = gr.Button(LABELS["copy_button"], variant="secondary")
        copy_button.click(
            fn=copy_to_clipboard, inputs=result_output, outputs=[],
            js="(text) => navigator.clipboard.writeText(text)",
        )
    with gr.Row(variant="default"):
        log_text = gr.Textbox(label="", container=False)

    if is_config_ui_installed:
        save_button.click(fn=save_config, inputs=[], outputs=log_text)
    generate_button.click(
        fn=enhance_prompt,
        inputs=[prompt_input, system_prompt, temperature, repetition_penalty, seed_output, is_rnd_seed],
        outputs=[seed_output, result_output, result_output_ru, log_text],
    )
    if is_google_translate_installed:
        radio_lang.change(
            process_lang,
            inputs=radio_lang,
            outputs=[
                log_text, generate_button, copy_button, save_button, prompt_input, seed_output,
                is_rnd_seed, result_output, result_output_ru, AccordionAdvanced, system_prompt,
                temperature, repetition_penalty,
            ],
        )
    radio_gpu.change(process_gpu, inputs=radio_gpu, outputs=log_text)
    # Preload the current language/device into the UI on page load.
    demo.load(set_initial, outputs=[radio_lang, radio_gpu, log_text])

launch_args = {}
if not is_running_in_huggingface_spaces():
    launch_args["share"] = False
    launch_args["server_name"] = "0.0.0.0"
    launch_args["inbrowser"] = True
launch_args["favicon_path"] = "./static/favicon_aicave.png"
launch_args["show_api"] = True
if os.path.exists("cert.pem") and os.path.exists("key.pem"):
    launch_args["ssl_certfile"] = "cert.pem"
    launch_args["ssl_keyfile"] = "key.pem"
    launch_args["ssl_verify"] = False
demo.launch(**launch_args)