fffiloni commited on
Commit
fff8451
·
verified ·
1 Parent(s): cdcfdd8

Create gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +118 -0
gradio_app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import imageio
3
+ import os
4
+ import gradio as gr
5
+ from diffusers.schedulers import EulerAncestralDiscreteScheduler
6
+ from transformers import T5EncoderModel, T5Tokenizer
7
+ from allegro.pipelines.pipeline_allegro import AllegroPipeline
8
+ from allegro.models.vae.vae_allegro import AllegroAutoencoderKL3D
9
+ from allegro.models.transformers.transformer_3d_allegro import AllegroTransformer3DModel
10
+
11
+ import huggingface_hub
12
+
13
+ weights_dir = './allegro_weights'
14
+ os.makedirs(weights_dir, exist_ok=True)
15
+
16
+ huggingface_hub.snapshot_download(
17
+ repo_id='rhymes-ai/Allegro',
18
+ allow_patterns=[
19
+ 'scheduler/**',
20
+ 'text_encoder/**',
21
+ 'tokenizer/**',
22
+ 'transformer/**',
23
+ 'vae/**',
24
+ ],
25
+ local_dir=weights_dir,
26
+ local_dir_use_symlinks=False,
27
+ )
28
+
29
+
30
+ def single_inference(user_prompt, save_path, guidance_scale, num_sampling_steps, seed, enable_cpu_offload):
31
+ dtype = torch.bfloat16
32
+
33
+ # Load models
34
+ vae = AllegroAutoencoderKL3D.from_pretrained(
35
+ "weights_dir/vae/",
36
+ torch_dtype=torch.float32
37
+ ).cuda()
38
+ vae.eval()
39
+
40
+ text_encoder = T5EncoderModel.from_pretrained("weights_dir/text_encoder/", torch_dtype=dtype)
41
+ text_encoder.eval()
42
+
43
+ tokenizer = T5Tokenizer.from_pretrained("weights_dir/tokenizer/")
44
+
45
+ scheduler = EulerAncestralDiscreteScheduler()
46
+
47
+ transformer = AllegroTransformer3DModel.from_pretrained("weights_dir/transformer/", torch_dtype=dtype).cuda()
48
+ transformer.eval()
49
+
50
+ allegro_pipeline = AllegroPipeline(
51
+ vae=vae,
52
+ text_encoder=text_encoder,
53
+ tokenizer=tokenizer,
54
+ scheduler=scheduler,
55
+ transformer=transformer
56
+ ).to("cuda:0")
57
+
58
+ positive_prompt = """
59
+ (masterpiece), (best quality), (ultra-detailed), (unwatermarked),
60
+ {}
61
+ emotional, harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo,
62
+ sharp focus, high budget, cinemascope, moody, epic, gorgeous
63
+ """
64
+
65
+ negative_prompt = """
66
+ nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality,
67
+ low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry.
68
+ """
69
+
70
+ # Process user prompt
71
+ user_prompt = positive_prompt.format(user_prompt.lower().strip())
72
+
73
+ if enable_cpu_offload:
74
+ allegro_pipeline.enable_sequential_cpu_offload()
75
+
76
+ out_video = allegro_pipeline(
77
+ user_prompt,
78
+ negative_prompt=negative_prompt,
79
+ num_frames=88,
80
+ height=720,
81
+ width=1280,
82
+ num_inference_steps=num_sampling_steps,
83
+ guidance_scale=guidance_scale,
84
+ max_sequence_length=512,
85
+ generator=torch.Generator(device="cuda:0").manual_seed(seed)
86
+ ).video[0]
87
+
88
+ # Save video
89
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
90
+ imageio.mimwrite(save_path, out_video, fps=15, quality=8)
91
+
92
+ return save_path
93
+
94
+
95
+ # Gradio interface function
96
+ def run_inference(user_prompt, guidance_scale, num_sampling_steps, seed, enable_cpu_offload):
97
+ save_path = "./output_videos/generated_video.mp4"
98
+ result_path = single_inference(user_prompt, save_path, guidance_scale, num_sampling_steps, seed, enable_cpu_offload)
99
+ return result_path
100
+
101
+
102
+ # Create Gradio interface
103
+ iface = gr.Interface(
104
+ fn=run_inference,
105
+ inputs=[
106
+ gr.Textbox(label="User Prompt"),
107
+ gr.Slider(minimum=0, maximum=20, step=0.1, label="Guidance Scale", value=7.5),
108
+ gr.Slider(minimum=10, maximum=200, step=1, label="Number of Sampling Steps", value=100),
109
+ gr.Slider(minimum=0, maximum=10000, step=1, label="Random Seed", value=42),
110
+ gr.Checkbox(label="Enable CPU Offload", value=False),
111
+ ],
112
+ outputs=gr.Video(label="Generated Video"),
113
+ title="Allegro Video Generation",
114
+ description="Generate a video based on a text prompt using the Allegro pipeline."
115
+ )
116
+
117
+ # Launch the interface
118
+ iface.launch()