Lisandro committed on
Commit
a60b83b
·
1 Parent(s): 5c6edfb

feat: Integrate FLUX.1 [schnell] model with Gradio for image generation

Browse files
Files changed (1) hide show
  1. app.py +30 -23
app.py CHANGED
@@ -1,32 +1,39 @@
1
  import gradio as gr
2
- import numpy as np
3
- import random
4
- import spaces
5
- import torch
6
- from diffusers import DiffusionPipeline
7
 
8
# Use bfloat16 to halve memory versus float32 while keeping enough dynamic range.
dtype = torch.bfloat16
# Prefer the GPU when one is available; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load FLUX.1 [schnell]; revision "refs/pr/1" pins a specific model-repo PR branch.
# NOTE(review): `dtype` above is defined but the call passes torch.bfloat16 directly.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, revision="refs/pr/1").to(device)

# Largest 32-bit signed integer — upper bound when drawing a random seed.
MAX_SEED = np.iinfo(np.int32).max
# Maximum allowed image dimension in pixels.
MAX_IMAGE_SIZE = 2048
15
 
16
@spaces.GPU()
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    """Generate a single image with the local FLUX.1 [schnell] pipeline.

    Args:
        prompt: Text prompt describing the image.
        seed: RNG seed; replaced by a random value when randomize_seed is set.
        randomize_seed: Draw a fresh seed in [0, MAX_SEED] instead of using `seed`.
        width: Output image width in pixels.
        height: Output image height in pixels.
        num_inference_steps: Diffusion steps (schnell is tuned for few steps).
        progress: Gradio progress tracker that mirrors tqdm output.

    Returns:
        A (PIL image, seed) tuple so the UI can display the seed actually used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Seed a dedicated generator so results are reproducible for a given seed.
    rng = torch.Generator().manual_seed(seed)
    output = pipe(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=rng,
        guidance_scale=0.0,
    )
    return output.images[0], seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  examples = [
32
  "a tiny astronaut hatching from an egg on the moon",
 
1
  import gradio as gr
2
+ from gradio_client import Client, handle_file
 
 
 
 
3
 
 
 
4
 
5
# URL of the hosted FLUX.1 [schnell] Space that performs the actual inference.
flux_1_schnell_space = "https://black-forest-labs-flux-1-schnell.hf.space"
# Lazily-initialized gradio_client.Client for the Space; created on first call to infer().
client = None
# Handle of the most recently submitted remote job.
# NOTE(review): written by infer() but never read across calls — could likely be a local.
job = None
8
 
 
 
9
 
 
10
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    """Proxy image generation to the remote FLUX.1 [schnell] Space.

    Lazily creates (and caches in the module-level `client`) a
    gradio_client.Client for the Space, then submits a job to its /infer
    endpoint and blocks until the result is available.

    Args:
        prompt: Text prompt describing the image.
        seed: RNG seed forwarded to the Space.
        randomize_seed: Let the Space pick a random seed instead of `seed`.
        width: Output image width in pixels.
        height: Output image height in pixels.
        num_inference_steps: Diffusion steps forwarded to the Space.
        progress: Gradio progress tracker that mirrors tqdm output.

    Returns:
        Whatever the remote /infer endpoint returns (its image and seed).

    Raises:
        gr.Error: if the client cannot be created or the remote job fails.
    """
    global job
    global client
    if client is None:
        try:
            client = Client(flux_1_schnell_space)
            print(f"Loaded custom model from {flux_1_schnell_space}")
        except ValueError as e:
            print(f"Failed to load custom model: {e}")
            client = None
            # Chain the original error so the traceback stays debuggable.
            raise gr.Error("Failed to load client for " + flux_1_schnell_space) from e

    try:
        job = client.submit(
            prompt=prompt,
            seed=seed,
            randomize_seed=randomize_seed,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            api_name="/infer"
        )
        result = job.result()
    # Remote failures are not limited to ValueError: job.result() re-raises
    # whatever the Space raised (connection errors, AppError, timeouts, ...).
    except Exception as e:
        # gr.Error expects a message string, not an exception object;
        # keep the cause chained for server-side debugging.
        raise gr.Error(str(e)) from e

    return result
37
 
38
  examples = [
39
  "a tiny astronaut hatching from an egg on the moon",