Akjava commited on
Commit
b6517cd
1 Parent(s): e377d12
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -2,14 +2,20 @@ import spaces
2
  import torch
3
  from diffusers import FluxInpaintPipeline
4
 
 
 
 
 
 
 
 
5
  import gradio as gr
6
  import re
7
  from PIL import Image
8
- import flux1_inpaint
9
  import os
10
  import numpy as np
11
- import shutil
12
- #shutil.rmtree("/home/user/app/.gradio/cached_examples/23")
13
 
14
  def sanitize_prompt(prompt):
15
  # Allow only alphanumeric characters, spaces, and basic punctuation
@@ -43,9 +49,7 @@ def process_images(image, image2=None,prompt="a girl",inpaint_model="black-fores
43
  def process_image(image,mask_image,prompt="a person",model_id="black-forest-labs/FLUX.1-schnell",strength=0.75,seed=0,num_inference_steps=4):
44
  if image == None:
45
  return None
46
-
47
- pipe = FluxInpaintPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
48
- pipe.to("cuda")
49
 
50
  generators = []
51
  generator = torch.Generator("cuda").manual_seed(seed)
 
2
  import torch
3
  from diffusers import FluxInpaintPipeline
4
 
5
+ dtype = torch.bfloat16
6
+ device = "cuda" if torch.cuda.is_available() else "cpu"
7
+
8
+ pipe = FluxInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(device)
9
+
10
+
11
+
12
  import gradio as gr
13
  import re
14
  from PIL import Image
15
+
16
  import os
17
  import numpy as np
18
+
 
19
 
20
  def sanitize_prompt(prompt):
21
  # Allow only alphanumeric characters, spaces, and basic punctuation
 
49
  def process_image(image,mask_image,prompt="a person",model_id="black-forest-labs/FLUX.1-schnell",strength=0.75,seed=0,num_inference_steps=4):
50
  if image == None:
51
  return None
52
+
 
 
53
 
54
  generators = []
55
  generator = torch.Generator("cuda").manual_seed(seed)