ovi054 commited on
Commit
795c18e
·
verified ·
1 Parent(s): d8b5a54

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -0
app.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import random
4
+ import requests
5
+ import gradio as gr
6
+ import numpy as np
7
+ from PIL import Image
8
+ import replicate
9
+
10
+
11
+ MAX_SEED = np.iinfo(np.int32).max
12
+
13
+
14
+ def predict(replicate_api, prompt, lora_id, lora_scale=0.95, aspect_ratio="1:1", seed=-1, randomize_seed=True, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
15
+
16
+ # Validate API key and prompt
17
+ if not replicate_api or not prompt:
18
+ return "Error: Missing necessary inputs.", -1
19
+
20
+ # Set the seed if randomize_seed is True
21
+ if randomize_seed:
22
+ seed = random.randint(0, MAX_SEED)
23
+
24
+ # Set the Replicate API token in the environment variable
25
+ os.environ["REPLICATE_API_TOKEN"] = replicate_api
26
+
27
+ # Construct the input for the replicate model
28
+ input_params = {
29
+ "prompt": prompt,
30
+ "output_format": "jpg",
31
+ "aspect_ratio": aspect_ratio,
32
+ "num_inference_steps": num_inference_steps,
33
+ "guidance_scale": guidance_scale,
34
+ "seed": seed,
35
+ "disable_safety_checker": True
36
+ }
37
+
38
+ # If lora_id is provided, include it in the input
39
+ if lora_id:
40
+ input_params["hf_lora"] = lora_id
41
+ input_params["lora_scale"] = lora_scale
42
+
43
+ try:
44
+ # Run the model using the user's API token from the environment variable
45
+ output = replicate.run(
46
+ "lucataco/flux-dev-lora:a22c463f11808638ad5e2ebd582e07a469031f48dd567366fb4c6fdab91d614d",
47
+ input=input_params
48
+ )
49
+ return output[0], seed # Return the generated image and seed
50
+
51
+ except Exception as e:
52
+ # Catch any exceptions, such as invalid API token or lack of credits
53
+ return f"Error: {str(e)}", -1
54
+
55
+ finally:
56
+ # Always remove the API key from the environment
57
+ if "REPLICATE_API_TOKEN" in os.environ:
58
+ del os.environ["REPLICATE_API_TOKEN"]
59
+
60
+
61
+
62
+ demo = gr.Interface(fn=predict, inputs="text", outputs="image")
63
+
64
+ css="""
65
+ #col-container {
66
+ margin: 0 auto;
67
+ max-width: 520px;
68
+ }
69
+ """
70
+ examples = [
71
+ "a tiny astronaut hatching from an egg on the moon",
72
+ "a cat holding a sign that says hello world",
73
+ "an anime illustration of a wiener schnitzel",
74
+ ]
75
+
76
+ with gr.Blocks(css=css) as demo:
77
+ with gr.Column(elem_id="col-container"):
78
+ gr.Markdown("# FLUX Dev with LoRA Support using Replicate API")
79
+
80
+ replicate_api = gr.Text(label="Replicate API", show_label=True, max_lines=1, placeholder="Enter Replicate API", container=False)
81
+ prompt = gr.Text(label="Prompt", show_label=True, max_lines=4, show_copy_button = True, placeholder="Enter your prompt", container=False)
82
+ with gr.Accordion("Advanced Settings", open=False):
83
+ with gr.Row():
84
+ custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path (optional)", placeholder="multimodalart/vintage-ads-flux")
85
+ lora_scale = gr.Slider(
86
+ label="LoRA Scale",
87
+ minimum=0,
88
+ maximum=1,
89
+ step=0.1,
90
+ value=0.95,
91
+ )
92
+ aspect_ratio = gr.Radio(label="Aspect ratio", value="1:1", choices=["1:1", "4:5", "2:3", "3:4","9:16", "4:3", "16:9"])
93
+ seed = gr.Slider(
94
+ label="Seed",
95
+ minimum=0,
96
+ maximum=MAX_SEED,
97
+ step=1,
98
+ value=0,
99
+ )
100
+
101
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
102
+
103
+ with gr.Row():
104
+ guidance_scale = gr.Slider(
105
+ label="Guidance Scale",
106
+ minimum=1,
107
+ maximum=15,
108
+ step=0.1,
109
+ value=3.5,
110
+ )
111
+ num_inference_steps = gr.Slider(
112
+ label="Number of inference steps",
113
+ minimum=1,
114
+ maximum=50,
115
+ step=1,
116
+ value=28,
117
+ )
118
+ submit = gr.Button("Generate Image", scale=1)
119
+
120
+ output = gr.Image(label="Output Image", show_label=True)
121
+
122
+
123
+ gr.Examples(
124
+ examples=examples,
125
+ fn=predict,
126
+ inputs=[prompt]
127
+ )
128
+ gr.on(
129
+ triggers=[submit.click, prompt.submit],
130
+ fn=predict,
131
+ inputs=[replicate_api, prompt, custom_lora, lora_scale, aspect_ratio, seed, randomize_seed, guidance_scale, num_inference_steps],
132
+ outputs = [output, seed]
133
+ )
134
+
135
+ demo.launch()