Jose Benitez commited on
Commit
b16249c
1 Parent(s): a1e077b

add credits function to training

Browse files
Files changed (2) hide show
  1. gradio_app.py +25 -4
  2. services/image_generation.py +1 -0
gradio_app.py CHANGED
@@ -67,6 +67,12 @@ def compress_and_train(request: gr.Request, files, model_name, trigger_word, tra
67
  return "No hay imágenes. Sube algunas imágenes para poder entrenar."
68
 
69
  user = request.session.get('user')
 
 
 
 
 
 
70
  if not user:
71
  raise gr.Error("User not authenticated. Please log in.")
72
 
@@ -98,7 +104,14 @@ def compress_and_train(request: gr.Request, files, model_name, trigger_word, tra
98
  autocaption=True,
99
  learning_rate=learning_rate)
100
 
101
- return gr.Info("Tu modelo esta entrenando, En unos 20 minutos estará listo para que lo pruebes en 'Generación'.")
 
 
 
 
 
 
 
102
 
103
  def run_lora(request: gr.Request, prompt, cfg_scale, steps, selected_index, selected_gallery, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
104
  user = request.session.get('user')
@@ -278,11 +291,19 @@ with gr.Blocks(theme=gr.themes.Soft(), head=header, css=main_css) as main_demo:
278
  batch_size = gr.Number(label='batch_size', value=1)
279
  learning_rate = gr.Number(label='learning_rate', value=0.0004)
280
  training_status = gr.Textbox(label="Training Status")
281
-
 
 
 
 
 
 
 
282
  train_button.click(
283
- compress_and_train,
 
284
  inputs=[train_dataset, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate],
285
- outputs=training_status
286
  )
287
 
288
 
 
67
  return "No hay imágenes. Sube algunas imágenes para poder entrenar."
68
 
69
  user = request.session.get('user')
70
+
71
+ _, training_credits = get_user_credits(user['id'])
72
+
73
+ if training_credits <= 0:
74
+ raise gr.Error("Ya no tienes creditos disponibles. Compra para continuar.")
75
+
76
  if not user:
77
  raise gr.Error("User not authenticated. Please log in.")
78
 
 
104
  autocaption=True,
105
  learning_rate=learning_rate)
106
 
107
+ new_training_credits = training_credits - 1
108
+ update_user_credits(user['id'], user['generation_credits'], new_training_credits)
109
+
110
+ # Update session data
111
+ user['training_credits'] = new_training_credits
112
+ request.session['user'] = user
113
+
114
+ return gr.Info("Tu modelo esta entrenando, En unos 20 minutos estará listo para que lo pruebes en 'Generación'."), new_training_credits
115
 
116
  def run_lora(request: gr.Request, prompt, cfg_scale, steps, selected_index, selected_gallery, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
117
  user = request.session.get('user')
 
291
  batch_size = gr.Number(label='batch_size', value=1)
292
  learning_rate = gr.Number(label='learning_rate', value=0.0004)
293
  training_status = gr.Textbox(label="Training Status")
294
+
295
+ def fake_train(train_dataset, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate):
296
+ print(f'fake training for test')
297
+ new_training_credits = 0
298
+ if new_training_credits <= 0:
299
+ raise gr.Error("Ya no tienes creditos disponibles. Compra para continuar.")
300
+ return gr.Info("Tu modelo esta entrenando, En unos 20 minutos estará listo para que lo pruebes en 'Generación'."), new_training_credits
301
+
302
  train_button.click(
303
+ #compress_and_train,
304
+ fake_train,
305
  inputs=[train_dataset, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate],
306
+ outputs=[training_status,train_credits_display]
307
  )
308
 
309
 
services/image_generation.py CHANGED
@@ -19,6 +19,7 @@ def generate_image(model_name, prompt, steps, cfg_scale, width, height, lora_sca
19
  }
20
  )
21
  else:
 
22
  img_url = replicate.run(
23
  model_name,
24
  input={
 
19
  }
20
  )
21
  else:
22
+ model_name = model_name.lower().replace(' ', '_')
23
  img_url = replicate.run(
24
  model_name,
25
  input={