import os import time import platform def is_running_in_huggingface_spaces(): return "SPACE_ID" in os.environ import gradio as gr import torch from transformers import AutoTokenizer, AutoModelForSeq2SeqLM try: # переводчик с русского на английский from google_translate import TranslatorWithCache is_google_translate_installed=True translator = TranslatorWithCache() except ImportError: is_google_translate_installed=False try: from config_ui import Config is_config_ui_installed=True config = Config() device = "cuda" if (config.cuda=="cuda" and torch.cuda.is_available()) else "cpu" lang=config.lang except ImportError: is_config_ui_installed=False if platform.system() == "Darwin" and platform.machine().startswith("arm"): print("run on mac with Apple Silicon") if torch.backends.mps.is_available(): device = torch.device("mps") # MPS = Metal Performance Shaders else: #TODO parse env var to assign cuda device device = 0 if torch.cuda.is_available() else "cpu" lang='EN' try: from prompt.portrait_prompt import generate_random_portrait_prompt is_rnd_gen_installed=True except: is_rnd_gen_installed=False model_checkpoint = "gokaygokay/Flux-Prompt-Enhance" tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device) max_target_length = 256 prefix = "enhance prompt" def enhance_prompt(prompt, system_prompt, temperature=0.5, repetition_penalty=1.2, seed=-1, is_rnd_seed=True): global lang start_time = time.time() # Начало замера времени if is_rnd_seed or seed==-1: seed = torch.randint(0, 2**32 - 1, (1,)).item() torch.manual_seed(seed) if is_google_translate_installed: # Перевод с русского на английский en_prompt = translator.translate_ru2eng(prompt) input_text = f"{system_prompt}: {en_prompt}" else: input_text = f"{system_prompt}: {prompt}" input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device) # Генерация текста outputs = model.generate( input_ids, max_length=max_target_length, num_return_sequences=1, do_sample=True, temperature=temperature, repetition_penalty=repetition_penalty ) generated_text_en = tokenizer.decode(outputs[0], skip_special_tokens=True) if is_google_translate_installed: result_output_ru = translator.translate_eng2ru(generated_text_en) else: result_output_ru=generated_text_en end_time = time.time() # Конец замера времени execution_time = end_time - start_time time_str=f"execution time: {execution_time:.2f} s." if lang=="EN" else f"время выполнения: {execution_time:.2f} с." return seed, generated_text_en, result_output_ru, time_str def random_prompt(): rnd_prompt_str=generate_random_portrait_prompt() #rnd_prompt_str=get_random_words() return rnd_prompt_str def copy_to_clipboard(text): gr.Info("скопировано в буффер обмена" if (lang=="RU") else "copy to clipboard" ,duration=1) return None LABELS_EN={"prompt_input": "Input initial prompt:", "seed_output": "Seed:", "result_output" : "Improved prompt", "result_output_ru" : "Improved prompt (in Russian)", "generate_button": "Improve prompt", "copy_button": "Copy to clipboard", "save_button": "Save config", "advanced": "Advanced settings:", "system_prompt" : "System prompt:", "temperature": "Temperature", "repetition_penalty": "Repetition penalty", "is_rnd_seed": "Random Seed" } LABELS=LABELS_EN if is_google_translate_installed: LABELS_RU={"prompt_input": "Введите начальный промпт:", "seed_output": "Seed для генерации:", "result_output" : "Улучшенный промпт (на английском):", "result_output_ru" : "Улучшенный промпт (на русском):", "generate_button": "Улучшить промпт", "copy_button": "Скопировать в буффер обмена", "save_button": "Сохранить настройки", "advanced": "Расширенные настройки:", "system_prompt": "Системный промпт:", "temperature": "Температура", "repetition_penalty": "Штраф за повторение", "is_rnd_seed": "Случайный Seed" } LABELS=LABELS_EN if lang=="EN" else LABELS_RU if is_google_translate_installed: def process_lang(selected_lang): global lang lang=selected_lang if selected_lang == "RU": LABELS=LABELS_RU message="Вы выбрали русский" isVisible=True elif selected_lang == "EN": LABELS=LABELS_EN message="You selected English" isVisible=False ret = [gr.update(value=LABELS["generate_button"]), gr.update(value=LABELS["copy_button"]), gr.update(value=LABELS["save_button"]), gr.update(label=LABELS["prompt_input"]), gr.update(label=LABELS["seed_output"]), gr.update(label=LABELS["is_rnd_seed"]), gr.update(label=LABELS["result_output"]), gr.update(visible=isVisible, label=LABELS["result_output_ru"]), gr.update(label=LABELS["advanced"]), gr.update(label=LABELS["system_prompt"]), gr.update(label=LABELS["temperature"]), gr.update(label=LABELS["repetition_penalty"]) ] return message, *ret if is_config_ui_installed: def save_config(): global lang,device,isOpenAdvanced ,config, AccordionAdvanced config.set_lang(lang) config.set_cuda(str(device)) isOpenAdvanced=AccordionAdvanced.open print(AccordionAdvanced.open) config.set_OpenAdvanced=(isOpenAdvanced) # Сохраняем изменения в файл config.save() return "save config to file" if lang=='EN' else "Конфигурация сохранена в файл" def process_gpu(selected_gpu): """Функция для переключения модели между устройствами (CPU / CUDA)""" global model, device # Используем глобальные переменные model и device device = torch.device(selected_gpu) # Устанавливаем новое устройство model = model.to(device) # Переносим модель на новое устройство message= f"Модель переключена на устройство: {selected_gpu}" if lang=="RU" else f"Model switched to device: {selected_gpu}" return message def set_initial(): global device dev="cpu" if str(device) =='cuda': device = torch.cuda.current_device() device_name = torch.cuda.get_device_name(device) device_name = f"GPU: {device_name}" dev="cuda" else: device_name = "running on CPU" return gr.update(value=lang), gr.update(value=dev), f'{device_name}, set to "{lang}" language' # Настройка интерфейса Gradio with gr.Blocks(title="Flux Prompt Enhance", theme=gr.themes.Default(primary_hue=gr.themes.colors.sky, secondary_hue=gr.themes.colors.indigo), analytics_enabled=False, css="footer{display:none !important}") as demo: gr.Image(label="header AiCave", value="./static/ai_cave_title.jpg",height="100%", show_download_button=False, show_label=False, show_share_button=False, interactive=False, show_fullscreen_button=False,) with gr.Row(variant="default"): running_in="for spaces" if is_running_in_huggingface_spaces() else "portable" gr.HTML(f'