Ricercar commited on
Commit
7824fd4
1 Parent(s): b83a9ff

update arxiv

Browse files
Files changed (2) hide show
  1. README.md +4 -4
  2. app.py +16 -4
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
  title: Diffusion Cocktail
3
  emoji: 🍸
4
- colorFrom: indigo
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 4.9.0
8
  app_file: app.py
9
  pinned: false
10
- python: 3.8
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Diffusion Cocktail
3
  emoji: 🍸
4
+ colorFrom: orange
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 4.7.1
8
  app_file: app.py
9
  pinned: false
10
+ python: 3.9.17
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -6,6 +6,7 @@ import torch
6
  import torchvision.transforms as T
7
 
8
  from clip_interrogator import Config, Interrogator
 
9
 
10
  from ditail import DitailDemo, seed_everything
11
 
@@ -74,6 +75,9 @@ class WebApp():
74
  gtag('config', '{self.gtag}');
75
  }}
76
  """
 
 
 
77
 
78
  self.debug_mode = debug_mode # turn off clip interrogator when debugging for faster building speed
79
  if not self.debug_mode:
@@ -81,7 +85,6 @@ class WebApp():
81
 
82
 
83
  def init_interrogator(self):
84
- # init clip interrogator
85
  config = Config()
86
  config.clip_model_name = self.args_base['clip_model_name']
87
  config.caption_model_name = self.args_base['caption_model_name']
@@ -89,16 +92,25 @@ class WebApp():
89
  self.ci.config.chunk_size = 2048 if self.ci.config.clip_model_name == "ViT-L-14/openai" else 1024
90
  self.ci.config.flavor_intermediate_count = 2048 if self.ci.config.clip_model_name == "ViT-L-14/openai" else 1024
91
 
 
 
 
 
 
 
 
 
 
92
  def title(self):
93
  gr.HTML(
94
  """
95
  <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
96
  <div>
97
  <h1 >Diffusion Cocktail 🍸: Fused Generation from Diffusion Models</h1>
98
- <div style="display: flex; justify-content: center; align-items: center; text-align: center; margin: 20px; gap: 10px;>
99
- <a class="flex-item" href="https://arxiv.org/abs/your-arxiv-id" target="_blank">
100
  <img src="https://img.shields.io/badge/arXiv-paper-darkred.svg" alt="arXiv Paper">
101
- </a>
102
  <a class="flex-item" href="https://MAPS-research.github.io/Ditail" target="_blank">
103
  <img src="https://img.shields.io/badge/Project_Page-Diffusion_Cocktail-yellow.svg" alt="Project Page">
104
  </a>
 
6
  import torchvision.transforms as T
7
 
8
  from clip_interrogator import Config, Interrogator
9
+ from diffusers import StableDiffusionPipeline
10
 
11
  from ditail import DitailDemo, seed_everything
12
 
 
75
  gtag('config', '{self.gtag}');
76
  }}
77
  """
78
+
79
+ # pre-download base model for better user experience
80
+ self._preload_pipeline()
81
 
82
  self.debug_mode = debug_mode # turn off clip interrogator when debugging for faster building speed
83
  if not self.debug_mode:
 
85
 
86
 
87
  def init_interrogator(self):
 
88
  config = Config()
89
  config.clip_model_name = self.args_base['clip_model_name']
90
  config.caption_model_name = self.args_base['caption_model_name']
 
92
  self.ci.config.chunk_size = 2048 if self.ci.config.clip_model_name == "ViT-L-14/openai" else 1024
93
  self.ci.config.flavor_intermediate_count = 2048 if self.ci.config.clip_model_name == "ViT-L-14/openai" else 1024
94
 
95
+
96
+ def _preload_pipeline(self):
97
+ for model in BASE_MODEL.values():
98
+ pipe = StableDiffusionPipeline.from_pretrained(
99
+ model, torch_dtype=torch.float16
100
+ ).to(self.args_base['device'])
101
+ pipe = None
102
+
103
+
104
  def title(self):
105
  gr.HTML(
106
  """
107
  <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
108
  <div>
109
  <h1 >Diffusion Cocktail 🍸: Fused Generation from Diffusion Models</h1>
110
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center; margin: 20px; gap: 10px;">
111
+ <a class="flex-item" href="https://arxiv.org/abs/2312.08873" target="_blank">
112
  <img src="https://img.shields.io/badge/arXiv-paper-darkred.svg" alt="arXiv Paper">
113
+ </a>
114
  <a class="flex-item" href="https://MAPS-research.github.io/Ditail" target="_blank">
115
  <img src="https://img.shields.io/badge/Project_Page-Diffusion_Cocktail-yellow.svg" alt="Project Page">
116
  </a>