Ffftdtd5dtft commited on
Commit
6bbb3ed
·
verified ·
1 Parent(s): 2f879d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -0
app.py CHANGED
@@ -7,6 +7,7 @@ from diffusers import (
7
  StableDiffusionImg2ImgPipeline,
8
  FluxPipeline,
9
  DiffusionPipeline,
 
10
  )
11
  from transformers import (
12
  pipeline as transformers_pipeline,
@@ -383,6 +384,48 @@ def retrain_models():
383
  pass
384
 
385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  text_to_image_pipeline = get_model_or_download(
387
  "stabilityai/stable-diffusion-2",
388
  "diffusers/text_to_image_model",
@@ -517,6 +560,23 @@ gemma_2_27b_it_pipeline = transformers_pipeline(
517
  model="google/gemma-2-27b-it",
518
  model_kwargs={"torch_dtype": torch.bfloat16},
519
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
520
 
521
  tools = []
522
 
@@ -686,11 +746,14 @@ gemma_2_2b_it_tab = gr.Interface(
686
  outputs=gr.Textbox(label="Gemma 2 2B IT Response:"),
687
  title="Gemma 2 2B IT",
688
  )
 
 
689
  def generate_gemma_2_27b(prompt):
690
  input_ids = gemma_2_27b_tokenizer(prompt, return_tensors="pt")
691
  outputs = gemma_2_27b_model.generate(**input_ids, max_new_tokens=32)
692
  return gemma_2_27b_tokenizer.decode(outputs[0])
693
 
 
694
  gemma_2_27b_tab = gr.Interface(
695
  fn=generate_gemma_2_27b,
696
  inputs=[gr.Textbox(label="Prompt:")],
@@ -703,6 +766,21 @@ gemma_2_27b_it_tab = gr.Interface(
703
  outputs=gr.Textbox(label="Gemma 2 27B IT Response:"),
704
  title="Gemma 2 27B IT",
705
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
706
 
707
  app = gr.TabbedInterface(
708
  [
@@ -734,6 +812,8 @@ app = gr.TabbedInterface(
734
  gemma_2_2b_it_tab,
735
  gemma_2_27b_tab,
736
  gemma_2_27b_it_tab,
 
 
737
  ],
738
  [
739
  "Generate Image",
@@ -764,6 +844,8 @@ app = gr.TabbedInterface(
764
  "Gemma 2 2B IT",
765
  "Gemma 2 27B",
766
  "Gemma 2 27B IT",
 
 
767
  ],
768
  )
769
 
 
7
  StableDiffusionImg2ImgPipeline,
8
  FluxPipeline,
9
  DiffusionPipeline,
10
+ DPMSolverMultistepScheduler,
11
  )
12
  from transformers import (
13
  pipeline as transformers_pipeline,
 
384
  pass
385
 
386
 
387
+ def generate_text_to_video_ms_1_7b(prompt, num_frames=200):
388
+ blob_name = f"diffusers/text_to_video_ms_1_7b:{prompt}:{num_frames}"
389
+ video_bytes = load_object_from_gcs(blob_name)
390
+ if not video_bytes:
391
+ try:
392
+ with tqdm(total=1, desc="Generating video") as pbar:
393
+ video_frames = text_to_video_ms_1_7b_pipeline(
394
+ prompt, num_inference_steps=25, num_frames=num_frames
395
+ ).frames
396
+ pbar.update(1)
397
+ video_path = export_to_video(video_frames)
398
+ with open(video_path, "rb") as f:
399
+ video_bytes = f.read()
400
+ save_object_to_gcs(blob_name, video_bytes)
401
+ os.remove(video_path)
402
+ except Exception as e:
403
+ print(f"Failed to generate video: {e}")
404
+ return None
405
+ return video_bytes
406
+
407
+
408
+ def generate_text_to_video_ms_1_7b_short(prompt):
409
+ blob_name = f"diffusers/text_to_video_ms_1_7b_short:{prompt}"
410
+ video_bytes = load_object_from_gcs(blob_name)
411
+ if not video_bytes:
412
+ try:
413
+ with tqdm(total=1, desc="Generating short video") as pbar:
414
+ video_frames = text_to_video_ms_1_7b_short_pipeline(
415
+ prompt, num_inference_steps=25
416
+ ).frames
417
+ pbar.update(1)
418
+ video_path = export_to_video(video_frames)
419
+ with open(video_path, "rb") as f:
420
+ video_bytes = f.read()
421
+ save_object_to_gcs(blob_name, video_bytes)
422
+ os.remove(video_path)
423
+ except Exception as e:
424
+ print(f"Failed to generate short video: {e}")
425
+ return None
426
+ return video_bytes
427
+
428
+
429
  text_to_image_pipeline = get_model_or_download(
430
  "stabilityai/stable-diffusion-2",
431
  "diffusers/text_to_image_model",
 
560
  model="google/gemma-2-27b-it",
561
  model_kwargs={"torch_dtype": torch.bfloat16},
562
  )
563
+ text_to_video_ms_1_7b_pipeline = DiffusionPipeline.from_pretrained(
564
+ "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
565
+ )
566
+ text_to_video_ms_1_7b_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
567
+ text_to_video_ms_1_7b_pipeline.scheduler.config
568
+ )
569
+ text_to_video_ms_1_7b_pipeline.enable_model_cpu_offload()
570
+ text_to_video_ms_1_7b_pipeline.enable_vae_slicing()
571
+ text_to_video_ms_1_7b_short_pipeline = DiffusionPipeline.from_pretrained(
572
+ "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
573
+ )
574
+ text_to_video_ms_1_7b_short_pipeline.scheduler = (
575
+ DPMSolverMultistepScheduler.from_config(
576
+ text_to_video_ms_1_7b_short_pipeline.scheduler.config
577
+ )
578
+ )
579
+ text_to_video_ms_1_7b_short_pipeline.enable_model_cpu_offload()
580
 
581
  tools = []
582
 
 
746
  outputs=gr.Textbox(label="Gemma 2 2B IT Response:"),
747
  title="Gemma 2 2B IT",
748
  )
749
+
750
+
751
  def generate_gemma_2_27b(prompt):
752
  input_ids = gemma_2_27b_tokenizer(prompt, return_tensors="pt")
753
  outputs = gemma_2_27b_model.generate(**input_ids, max_new_tokens=32)
754
  return gemma_2_27b_tokenizer.decode(outputs[0])
755
 
756
+
757
  gemma_2_27b_tab = gr.Interface(
758
  fn=generate_gemma_2_27b,
759
  inputs=[gr.Textbox(label="Prompt:")],
 
766
  outputs=gr.Textbox(label="Gemma 2 27B IT Response:"),
767
  title="Gemma 2 27B IT",
768
  )
769
+ text_to_video_ms_1_7b_tab = gr.Interface(
770
+ fn=generate_text_to_video_ms_1_7b,
771
+ inputs=[
772
+ gr.Textbox(label="Prompt:"),
773
+ gr.Slider(50, 200, 200, step=1, label="Number of Frames:"),
774
+ ],
775
+ outputs=gr.Video(),
776
+ title="Text to Video MS 1.7B",
777
+ )
778
+ text_to_video_ms_1_7b_short_tab = gr.Interface(
779
+ fn=generate_text_to_video_ms_1_7b_short,
780
+ inputs=[gr.Textbox(label="Prompt:")],
781
+ outputs=gr.Video(),
782
+ title="Text to Video MS 1.7B Short",
783
+ )
784
 
785
  app = gr.TabbedInterface(
786
  [
 
812
  gemma_2_2b_it_tab,
813
  gemma_2_27b_tab,
814
  gemma_2_27b_it_tab,
815
+ text_to_video_ms_1_7b_tab,
816
+ text_to_video_ms_1_7b_short_tab,
817
  ],
818
  [
819
  "Generate Image",
 
844
  "Gemma 2 2B IT",
845
  "Gemma 2 27B",
846
  "Gemma 2 27B IT",
847
+ "Text to Video MS 1.7B",
848
+ "Text to Video MS 1.7B Short",
849
  ],
850
  )
851