sayakpaul HF staff committed on
Commit
80ac042
1 Parent(s): 3558706

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +106 -0
README.md ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: diffusers
3
+ license: other
4
+ license_name: flux-1-dev-non-commercial-license
5
+ license_link: LICENSE.md
6
+ ---
7
+
8
+ > [!NOTE]
9
+ > Contains the NF4 checkpoints (`transformer` and `text_encoder_2`) of [`black-forest-labs/FLUX.1-Depth-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev). Please adhere to the original model licensing!
10
+
11
+ <details>
12
+ <summary>Code</summary>
13
+
14
+ ```py
15
+ # !pip install git+https://github.com/asomoza/image_gen_aux.git
16
+ from diffusers import DiffusionPipeline, FluxControlPipeline, FluxTransformer2DModel
17
+ import torch
18
+ from transformers import T5EncoderModel
19
+ from image_gen_aux import DepthPreprocessor
20
+ from diffusers.utils import load_image
21
+ import fire
22
+
23
+
24
+ def load_pipeline(four_bit=False):
25
+ orig_pipeline = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
26
+ if four_bit:
27
+ print("Using four bit.")
28
+ transformer = FluxTransformer2DModel.from_pretrained(
29
+ "sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="transformer", torch_dtype=torch.bfloat16
30
+ )
31
+ text_encoder_2 = T5EncoderModel.from_pretrained(
32
+ "sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
33
+ )
34
+ pipeline = FluxControlPipeline.from_pipe(
35
+ orig_pipeline, transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=torch.bfloat16
36
+ )
37
+ else:
38
+ transformer = FluxTransformer2DModel.from_pretrained(
39
+ "black-forest-labs/FLUX.1-Depth-dev",
40
+ subfolder="transformer",
41
+ revision="refs/pr/1",
42
+ torch_dtype=torch.bfloat16,
43
+ )
44
+ pipeline = FluxControlPipeline.from_pipe(orig_pipeline, transformer=transformer, torch_dtype=torch.bfloat16)
45
+
46
+ pipeline.enable_model_cpu_offload()
47
+ return pipeline
48
+
49
+ @torch.no_grad()
50
+ def get_depth(control_image):
51
+ processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
52
+ control_image = processor(control_image)[0].convert("RGB")
53
+ return control_image
54
+
55
+ def load_conditions():
56
+ prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
57
+ control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
58
+ control_image = get_depth(control_image)
59
+ return prompt, control_image
60
+
61
+
62
+ def main(four_bit: bool = False):
63
+ ckpt_id = "sayakpaul/FLUX.1-Depth-dev-nf4"
64
+ pipe = load_pipeline(four_bit=four_bit)
65
+ prompt, control_image = load_conditions()
66
+ image = pipe(
67
+ prompt=prompt,
68
+ control_image=control_image,
69
+ height=1024,
70
+ width=1024,
71
+ num_inference_steps=30,
72
+ guidance_scale=10.0,
73
+ max_sequence_length=512,
74
+ generator=torch.Generator("cpu").manual_seed(0),
75
+ ).images[0]
76
+ filename = "output_" + ckpt_id.split("/")[-1].replace(".", "_")
77
+ filename += "_4bit" if four_bit else ""
78
+ image.save(f"{filename}.png")
79
+
80
+
81
+ if __name__ == "__main__":
82
+ fire.Fire(main)
83
+ ```
84
+
85
+ </details>
86
+
87
+ ## Outputs
88
+
89
+ <table>
90
+ <thead>
91
+ <tr>
92
+ <th>Original</th>
93
+ <th>NF4</th>
94
+ </tr>
95
+ </thead>
96
+ <tbody>
97
+ <tr>
98
+ <td>
99
+ <img src="./assets/output_FLUX_1-Fill-dev.png" alt="Original">
100
+ </td>
101
+ <td>
102
+ <img src="./assets/output_FLUX_1-Fill-dev_4bit.png" alt="NF4">
103
+ </td>
104
+ </tr>
105
+ </tbody>
106
+ </table>