Borcherding committed (verified)
Commit 16eb805 · 1 Parent(s): 9261f1e

Update app.py

Files changed (1):
  1. app.py +39 -13
app.py CHANGED

@@ -4,21 +4,38 @@ import spaces
 import torch
 import random
 from peft import PeftModel
-
 from diffusers import FluxControlPipeline, FluxTransformer2DModel
 from image_gen_aux import DepthPreprocessor

 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048

-pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Depth-dev", torch_dtype=torch.bfloat16).to("cuda")
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Initialize the pipeline and move it to GPU
+pipe = FluxControlPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-Depth-dev",
+    torch_dtype=torch.bfloat16
+).to(device)
+
 processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")

 def load_lora(lora_path):
     if not lora_path.strip():
         return "Please provide a valid LoRA path"
     try:
+        # Unload any existing LoRA weights first
+        try:
+            pipe.unload_lora_weights()
+        except:
+            pass
+
+        # Load new LoRA weights and move to the same device
         pipe.load_lora_weights(lora_path)
+
+        # Ensure all model components are on the correct device
+        pipe.to(device)
+
         return f"Successfully loaded LoRA weights from {lora_path}"
     except Exception as e:
         return f"Error loading LoRA weights: {str(e)}"
@@ -26,6 +43,8 @@ def load_lora(lora_path):
 def unload_lora():
     try:
         pipe.unload_lora_weights()
+        # Ensure model is on correct device after unloading
+        pipe.to(device)
         return "Successfully unloaded LoRA weights"
     except Exception as e:
         return f"Error unloading LoRA weights: {str(e)}"
@@ -36,18 +55,25 @@ def infer(control_image, prompt, seed=42, randomize_seed=False, width=1024, heig

     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
+
+    try:
+        # Process control image
+        control_image = processor(control_image)[0].convert("RGB")

-    control_image = processor(control_image)[0].convert("RGB")
-    image = pipe(
-        prompt=prompt,
-        control_image=control_image,
-        height=height,
-        width=width,
-        num_inference_steps=num_inference_steps,
-        guidance_scale=guidance_scale,
-        generator=torch.Generator().manual_seed(seed),
-    ).images[0]
-    return image, seed
+        # Generate image
+        image = pipe(
+            prompt=prompt,
+            control_image=control_image,
+            height=height,
+            width=width,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            generator=torch.Generator(device=device).manual_seed(seed),
+        ).images[0]
+
+        return image, seed
+    except Exception as e:
+        return None, f"Error during inference: {str(e)}"

 css="""
 #col-container {
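
For reference, the device-fallback pattern this commit introduces (pick "cuda" only when it is available, and create the generator on that same device) can be exercised end to end in a standalone script roughly as follows. This is a sketch under stated assumptions, not part of the commit: it assumes local access to the FLUX.1-Depth-dev and depth-anything-large-hf weights, and the input filename, prompt, and sampling settings are placeholder values rather than anything taken from app.py.

    # Standalone sketch of the CPU/GPU fallback pattern used in the diff (illustrative only).
    # Assumes the model weights can be downloaded and that "input.png" exists locally.
    import torch
    from diffusers import FluxControlPipeline
    from image_gen_aux import DepthPreprocessor
    from PIL import Image

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Same setup as app.py: bfloat16 weights, moved to whichever device is available.
    pipe = FluxControlPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Depth-dev",
        torch_dtype=torch.bfloat16,
    ).to(device)
    processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")

    # Build a depth control image from a placeholder input file.
    control_image = processor(Image.open("input.png"))[0].convert("RGB")

    # The generator is created on the same device as the pipeline, as in the diff.
    image = pipe(
        prompt="a sunlit reading nook with a leather armchair",  # placeholder prompt
        control_image=control_image,
        height=1024,
        width=1024,
        num_inference_steps=30,   # placeholder sampling settings
        guidance_scale=10.0,
        generator=torch.Generator(device=device).manual_seed(42),
    ).images[0]
    image.save("output.png")

On a CPU-only machine this will be extremely slow, but it should no longer fail at pipeline construction the way a hard-coded .to("cuda") does, which is the point of the device fallback.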