Spaces:
Sleeping
Sleeping
import os | |
import zipfile | |
from pathlib import Path | |
import gradio as gr | |
from database import ( | |
get_user_credits, | |
update_user_credits, | |
get_lora_models_info, | |
get_user_lora_models | |
) | |
from services.image_generation import generate_image | |
from services.train_lora import lora_pipeline | |
from utils.image_utils import url_to_pil_image | |
from utils.file_utils import load_file_content | |
LORA_MODELS = get_lora_models_info() | |
if not isinstance(LORA_MODELS, list): | |
raise ValueError("Expected loras_models to be a list of dictionaries.") | |
BASE_DIR = Path(__file__).parent | |
LOGIN_CSS_PATH = BASE_DIR / 'static/css/login.css' | |
MAIN_CSS_PATH = BASE_DIR / 'static/css/main.css' | |
LANDING_HTML_PATH = BASE_DIR / 'static/html/landing.html' | |
MAIN_HEADER_PATH = BASE_DIR / 'static/html/main_header.html' | |
LOGIN_CSS = load_file_content(LOGIN_CSS_PATH) | |
MAIN_CSS = load_file_content(MAIN_CSS_PATH) | |
LANDING_PAGE = load_file_content(LANDING_HTML_PATH) | |
MAIN_HEADER = load_file_content(MAIN_HEADER_PATH) | |
def load_user_models(request: gr.Request): | |
user = request.session.get('user') | |
print(user) | |
if user: | |
user_models = get_user_lora_models(user['id']) | |
if user_models: | |
return [(item.get("image_url", "assets/logo.jpg"), item["lora_name"]) for item in user_models] | |
return [] | |
def update_selection(evt: gr.SelectData, gallery_type: str, width, height): | |
if gallery_type == "user": | |
selected_lora = {"lora_name": "custom", "trigger_word": "custom"} | |
else: | |
selected_lora = LORA_MODELS[evt.index] | |
new_placeholder = f"Enter a prompt for {selected_lora['lora_name']}" | |
trigger_word = selected_lora["trigger_word"] | |
updated_text = f"#### Trigger Word: {trigger_word} ✨" | |
if "aspect" in selected_lora: | |
if selected_lora["aspect"] == "portrait": | |
width, height = 768, 1024 | |
elif selected_lora["aspect"] == "landscape": | |
width, height = 1024, 768 | |
return gr.update(placeholder=new_placeholder), updated_text, evt.index, width, height, gallery_type | |
def compress_and_train(request: gr.Request, files, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate): | |
if not files: | |
return "No Images. Please, upload some images to start training" | |
user = request.session.get('user') | |
_, training_credits = get_user_credits(user['id']) | |
if training_credits <= 0: | |
raise gr.Error("You ran out of credtis. Please buy more to continue") | |
if not user: | |
raise gr.Error("User not authenticated. Please log in.") | |
user_id = user['id'] | |
# Create a directory in the user's home folder | |
output_dir = os.path.expanduser("~/gradio_training_data") | |
os.makedirs(output_dir, exist_ok=True) | |
# Create a zip file in the output directory | |
zip_path = os.path.join(output_dir, "training_data.zip") | |
with zipfile.ZipFile(zip_path, 'w') as zipf: | |
for file_info in files: | |
file_path = file_info[0] # The first element of the tuple is the file path | |
file_name = os.path.basename(file_path) | |
zipf.write(file_path, file_name) | |
print(f"Zip file created at: {zip_path}") | |
print(f'[INFO] Procesando {trigger_word}') | |
# Now call the train_lora function with the zip file path | |
result = lora_pipeline(user_id, | |
zip_path, | |
model_name, | |
trigger_word=trigger_word, | |
steps=train_steps, | |
lora_rank=lora_rank, | |
batch_size=batch_size, | |
autocaption=True, | |
learning_rate=learning_rate) | |
new_training_credits = training_credits - 1 | |
update_user_credits(user['id'], user['generation_credits'], new_training_credits) | |
# Update session data | |
user['training_credits'] = new_training_credits | |
request.session['user'] = user | |
return gr.Info("Your model is training. In about 20 minutes, it will be ready for you to test in 'Generation"), new_training_credits | |
def run_lora(request: gr.Request, prompt, cfg_scale, steps, selected_index, selected_gallery, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)): | |
user = request.session.get('user') | |
if not user: | |
raise gr.Error("User not authenticated. Please log in.") | |
lora_models = get_user_lora_models(user['id']) | |
print(f'Selected gallery: {selected_gallery}') | |
if selected_gallery == "user": | |
lora_models = get_user_lora_models(user['id']) | |
print('Using user models') | |
else: # public | |
lora_models = get_lora_models_info() | |
print('Using public models') | |
print(f'Selected index: {selected_index}') | |
if selected_index is None: | |
selected_lora = None | |
else: | |
selected_lora = lora_models[selected_index] | |
generation_credits, _ = get_user_credits(user['id']) | |
if selected_lora: | |
print(f"Selected Lora: {selected_lora['lora_name']}") | |
model_name = selected_lora['lora_name'] | |
use_default = False | |
else: | |
model_name = "black-forest-labs/flux-pro" | |
print(f"Using default Lora: {model_name}") | |
use_default = True | |
if generation_credits <= 0: | |
raise gr.Error("Ya no tienes creditos disponibles. Compra para continuar.") | |
image_url = generate_image(model_name, prompt, steps, cfg_scale, width, height, lora_scale, progress, use_default) | |
image = url_to_pil_image(image_url) | |
# Update user's credits | |
new_generation_credits = generation_credits - 1 | |
update_user_credits(user['id'], new_generation_credits, user['train_credits']) | |
# Update session data | |
user['generation_credits'] = new_generation_credits | |
request.session['user'] = user | |
print(f"Generation credits remaining: {new_generation_credits}") | |
return image, new_generation_credits | |
def display_credits(request: gr.Request): | |
user = request.session.get('user') | |
if user: | |
generation_credits, train_credits = get_user_credits(user['id']) | |
return generation_credits, train_credits | |
return 0, 0 | |
def load_greet_and_credits(request: gr.Request): | |
greeting = greet(request) | |
generation_credits, train_credits = display_credits(request) | |
return greeting, generation_credits, train_credits | |
def greet(request: gr.Request): | |
user = request.session.get('user') | |
if user: | |
with gr.Column(): | |
with gr.Row(): | |
greeting = f"Hola 👋 {user['given_name']}!" | |
return f"{greeting}\n" | |
return "OBTU AI. Please log in." | |
with gr.Blocks(theme=gr.themes.Soft(), css=LOGIN_CSS) as login_demo: | |
with gr.Column(elem_id="google-btn-container", elem_classes="google-btn-container svelte-vt1mxs gap"): | |
btn = gr.Button("Sign In with Google", elem_classes="login-with-google-btn") | |
_js_redirect = """ | |
() => { | |
url = '/login' + window.location.search; | |
window.open(url, '_blank'); | |
} | |
""" | |
btn.click(None, js=_js_redirect) | |
gr.HTML(LANDING_PAGE) | |
header = '<script src="https://cdn.lordicon.com/lordicon.js"></script>' | |
with gr.Blocks(theme=gr.themes.Soft(), head=header, css=MAIN_CSS) as main_demo: | |
title = gr.HTML(MAIN_HEADER) | |
with gr.Column(elem_id="logout-btn-container"): | |
gr.Button("Logout", link="/logout", elem_id="logout_btn") | |
greetings = gr.Markdown("Loading user information...") | |
selected_index = gr.State(None) | |
with gr.Row(): | |
with gr.Column(): | |
generation_credits_display = gr.Number(label="Generation Credits", precision=0, interactive=False) | |
with gr.Column(): | |
train_credits_display = gr.Number(label="Training Credits", precision=0, interactive=False) | |
with gr.Column(): | |
gr.Button("Buy Credits 💳", link="/buy_credits") | |
with gr.Tabs(): | |
with gr.TabItem('Create'): | |
with gr.Row(): | |
with gr.Column(scale=3): | |
prompt = gr.Textbox(label="Prompt", | |
lines=1, | |
placeholder="Enter Your Prompt to start creating 📷", | |
info='Some public models may experience longer processing times due to server availability and queue management.') | |
with gr.Column(scale=1, elem_id="gen_column"): | |
generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn") | |
with gr.Row(): | |
with gr.Column(scale=4): | |
result = gr.Image(label="Imagen Generada") | |
with gr.Column(scale=3): | |
with gr.Accordion("Public Models"): | |
selected_info = gr.Markdown("") | |
gallery = gr.Gallery( | |
[(item["image_url"], item["model_name"]) for item in LORA_MODELS], | |
label="Public Models", | |
allow_preview=False, | |
columns=3, | |
elem_id="gallery" | |
) | |
with gr.Accordion("Your Models"): | |
user_model_gallery = gr.Gallery( | |
label="Galeria de Modelos", | |
allow_preview=False, | |
columns=3, | |
elem_id="galley" | |
) | |
gallery_type = gr.State("Public") | |
with gr.Accordion("Advanced Settings", open=False): | |
with gr.Row(): | |
cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5) | |
steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28) | |
with gr.Row(): | |
width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024) | |
height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024) | |
with gr.Row(): | |
randomize_seed = gr.Checkbox(True, label="Randomize seed") | |
lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95) | |
gallery.select( | |
update_selection, | |
inputs=[gr.State("public"), width, height], | |
outputs=[prompt, selected_info, selected_index, width, height, gallery_type] | |
) | |
user_model_gallery.select( | |
update_selection, | |
inputs=[gr.State("user"), width, height], | |
outputs=[prompt, selected_info, selected_index, width, height, gallery_type] | |
) | |
gr.on( | |
triggers=[generate_button.click, prompt.submit], | |
fn=run_lora, | |
inputs=[prompt, cfg_scale, steps, selected_index, gallery_type, width, height, lora_scale], | |
outputs=[result, generation_credits_display] | |
) | |
with gr.TabItem("Train"): | |
gr.Markdown("# Train your own model 🧠") | |
gr.Markdown("In this section, you can train your own model using your images.") | |
with gr.Row(): | |
with gr.Column(): | |
train_dataset = gr.Gallery(columns=4, interactive=True, label="Tus Imagenes") | |
model_name = gr.Textbox(label="Model Name",) | |
trigger_word = gr.Textbox(label="Trigger Word", | |
info="This will be a keyword to later instruct the model when to use these new capabilities we're going to teach it", | |
) | |
train_button = gr.Button("Start Training") | |
with gr.Accordion("Advanced Settings", open=False): | |
train_steps = gr.Slider(label="Training Steps", minimum=100, maximum=10000, step=100, value=1000) | |
lora_rank = gr.Number(label='lora_rank', value=16) | |
batch_size = gr.Number(label='batch_size', value=1) | |
learning_rate = gr.Number(label='learning_rate', value=0.0004) | |
training_status = gr.Textbox(label="Training Status") | |
train_button.click( | |
compress_and_train, | |
inputs=[train_dataset, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate], | |
outputs=[training_status,train_credits_display] | |
) | |
main_demo.load(load_user_models, None, user_model_gallery) | |
main_demo.load(load_greet_and_credits, None, [greetings, generation_credits_display, train_credits_display]) |