import replicate
import os
from huggingface_hub import create_repo
from database import create_lora_models

REPLICATE_OWNER = "josebenitezg"

def lora_pipeline(user_id, zip_path, model_name, trigger_word="TOK", steps=1000, lora_rank=16, batch_size=1, autocaption=True, learning_rate=0.0004):
    """Create a Replicate model and a Hugging Face repo, then launch a FLUX.1 LoRA training job."""
    print(f'Creating dataset for {model_name}')
    # Normalize the display name into a slug usable in repo identifiers
    model_name = model_name.lower().replace(' ', '_')
    hf_repo_name = f"joselobenitezg/flux-dev-{model_name}"
    replicate_repo_name = f"{REPLICATE_OWNER}/flux-dev-{model_name}"
    create_repo(hf_repo_name, repo_type='model')

    lora_name = f"flux-dev-{model_name}"
    
    model = replicate.models.create(
        owner=REPLICATE_OWNER,
        name=lora_name,
        visibility="private",  # or "private" if you prefer
        hardware="gpu-t4",  # Replicate will override this for fine-tuned models
        description="A fine-tuned FLUX.1 model"
    )

    print(f"Model created: {model.name}")
    print(f"Model URL: https://replicate.com/{model.owner}/{model.name}")

    # Use the newly created model as the destination for the training run
    print("[INFO] Starting training")

    print(f'\n[INFO] Training parameters:\n Trigger word: {trigger_word}\n steps: {steps}\n lora_rank: {lora_rank}\n autocaption: {autocaption}\n learning_rate: {learning_rate}\n')
    training = replicate.trainings.create(
        version="ostris/flux-dev-lora-trainer:1296f0ab2d695af5a1b5eeee6e8ec043145bef33f1675ce1a2cdb0f81ec43f02",
        input={
            "input_images": open(zip_path, "rb"),
            "steps": steps,
            "lora_rank": lora_rank,
            "batch_size": batch_size,
            "autocaption": autocaption,
            "trigger_word": trigger_word,
            "learning_rate": learning_rate,
            "hf_token": os.getenv('HF_TOKEN'),  # optional
            "hf_repo_id": hf_repo_name,  # optional
        },
        destination=f"{model.owner}/{model.name}"
    )

    print(f"training: {training.keys()}")
    print(f"Training started: {training.status}")
    print(f"Training URL: https://replicate.com/p/{training.id}")
    print(f"Creating model in Database")
    training_url = f"https://replicate.com/p/{training.id}"
    create_lora_models(user_id, replicate_repo_name, trigger_word, steps, lora_rank, batch_size, learning_rate, hf_repo_name, training_url)
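
# Example usage (a minimal sketch): the user_id, zip path, and model name below
# are hypothetical placeholders. Assumes REPLICATE_API_TOKEN and HF_TOKEN are
# set in the environment, and that the zip archive of training images exists.
if __name__ == "__main__":
    lora_pipeline(
        user_id="user_123",        # hypothetical database user id
        zip_path="dataset.zip",    # hypothetical path to the training images
        model_name="My Style",     # normalized to "my_style" inside the pipeline
        trigger_word="TOK",
        steps=1000,
    )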