Jose Benitez commited on
Commit
0fb8049
1 Parent(s): 52fe18e

minor fixes

Browse files
Files changed (3) hide show
  1. config.py +3 -0
  2. main.py +0 -2
  3. services/train_lora.py +2 -4
config.py CHANGED
@@ -16,6 +16,9 @@ STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET")
16
 
17
  DOMAIN = os.getenv("DOMAIN")
18
 
 
 
 
19
 
20
  logging.basicConfig(level=logging.INFO)
21
  logger = logging.getLogger(__name__)
 
16
 
17
  DOMAIN = os.getenv("DOMAIN")
18
 
19
+ REPLICATE_OWNER = "josebenitezg"
20
+ HF_OWNER = "joselobenitezg"
21
+
22
 
23
  logging.basicConfig(level=logging.INFO)
24
  logger = logging.getLogger(__name__)
main.py CHANGED
@@ -7,7 +7,6 @@ from routes import router, get_user
7
  from gradio_app import login_demo, main_demo
8
  import gradio as gr
9
  from pathlib import Path
10
- from fastapi.middleware.cors import CORSMiddleware
11
 
12
  app = FastAPI()
13
 
@@ -16,7 +15,6 @@ main_demo.queue()
16
 
17
  static_dir = Path("./static")
18
  app.mount("/static", StaticFiles(directory=static_dir, html=True), name="static")
19
- #app.mount("/assets", StaticFiles(directory="assets", html=True), name="assets")
20
 
21
  app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY, max_age=3600)
22
 
 
7
  from gradio_app import login_demo, main_demo
8
  import gradio as gr
9
  from pathlib import Path
 
10
 
11
  app = FastAPI()
12
 
 
15
 
16
  static_dir = Path("./static")
17
  app.mount("/static", StaticFiles(directory=static_dir, html=True), name="static")
 
18
 
19
  app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY, max_age=3600)
20
 
services/train_lora.py CHANGED
@@ -2,13 +2,12 @@ import replicate
2
  import os
3
  from huggingface_hub import create_repo
4
  from database import create_lora_models
5
-
6
- REPLICATE_OWNER = "josebenitezg"
7
 
8
  def lora_pipeline(user_id, zip_path, model_name, trigger_word="TOK", steps=1000, lora_rank=16, batch_size=1, autocaption=True, learning_rate=0.0004):
9
  print(f'Creating dataset for {model_name}')
10
  model_name = model_name.lower().replace(' ', '_')
11
- hf_repo_name = f"joselobenitezg/flux-dev-{model_name}"
12
  replicate_repo_name = f"josebenitezg/flux-dev-{model_name}"
13
  create_repo(hf_repo_name, repo_type='model')
14
 
@@ -45,7 +44,6 @@ def lora_pipeline(user_id, zip_path, model_name, trigger_word="TOK", steps=1000,
45
  destination=f"{model.owner}/{model.name}"
46
  )
47
 
48
- print(f"training: {training.keys()}")
49
  print(f"Training started: {training.status}")
50
  print(f"Training URL: https://replicate.com/p/{training.id}")
51
  print(f"Creating model in Database")
 
2
  import os
3
  from huggingface_hub import create_repo
4
  from database import create_lora_models
5
+ from config import REPLICATE_OWNER, HF_OWNER
 
6
 
7
  def lora_pipeline(user_id, zip_path, model_name, trigger_word="TOK", steps=1000, lora_rank=16, batch_size=1, autocaption=True, learning_rate=0.0004):
8
  print(f'Creating dataset for {model_name}')
9
  model_name = model_name.lower().replace(' ', '_')
10
+ hf_repo_name = f"{HF_OWNER}/flux-dev-{model_name}"
11
  replicate_repo_name = f"josebenitezg/flux-dev-{model_name}"
12
  create_repo(hf_repo_name, repo_type='model')
13
 
 
44
  destination=f"{model.owner}/{model.name}"
45
  )
46
 
 
47
  print(f"Training started: {training.status}")
48
  print(f"Training URL: https://replicate.com/p/{training.id}")
49
  print(f"Creating model in Database")