juancopi81 commited on
Commit
d410924
β€’
1 Parent(s): 5d86663

Add google colab

Browse files
Files changed (2) hide show
  1. app.py +40 -11
  2. utils.py +6 -0
app.py CHANGED
@@ -3,18 +3,27 @@ import gradio as gr
3
  import torch
4
  from torch import autocast
5
  from diffusers import DiffusionPipeline
6
- import streamlit as st
7
  from transformers import (
8
  pipeline,
9
  MBart50TokenizerFast,
10
  MBartForConditionalGeneration,
11
  )
12
 
 
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
  device_dict = {"cuda": 0, "cpu": -1}
15
  context = autocast if device == "cuda" else nullcontext
16
  dtype = torch.float16 if device == "cuda" else torch.float32
17
 
 
 
 
 
 
 
 
 
18
  # Add language detection pipeline
19
  language_detection_model_ckpt = "papluca/xlm-roberta-base-language-detection"
20
  language_detection_pipeline = pipeline("text-classification",
@@ -27,16 +36,28 @@ trans_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-larg
27
 
28
  model_id = "CompVis/stable-diffusion-v1-4"
29
 
30
- pipe = DiffusionPipeline.from_pretrained(
31
- model_id,
32
- custom_pipeline="multilingual_stable_diffusion",
33
- use_auth_token=st.secrets["USER_TOKEN"],
34
- detection_pipeline=language_detection_pipeline,
35
- translation_model=trans_model,
36
- translation_tokenizer=trans_tokenizer,
37
- revision="fp16",
38
- torch_dtype=dtype,
39
- )
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  pipe = pipe.to(device)
42
 
@@ -228,6 +249,10 @@ with block as demo:
228
  <p style="margin-bottom: 10px; font-size: 94%">
229
  Stable Diffusion Pipeline that supports prompts in 50 different languages.
230
  </p>
 
 
 
 
231
  </div>
232
  """
233
  )
@@ -277,4 +302,8 @@ Despite how impressive being able to turn text into image is, beware to the fact
277
  </div>
278
  """
279
  )
 
 
 
 
280
  demo.queue(max_size=25).launch()
 
3
  import torch
4
  from torch import autocast
5
  from diffusers import DiffusionPipeline
 
6
  from transformers import (
7
  pipeline,
8
  MBart50TokenizerFast,
9
  MBartForConditionalGeneration,
10
  )
11
 
12
+ import utils
13
+
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  device_dict = {"cuda": 0, "cpu": -1}
16
  context = autocast if device == "cuda" else nullcontext
17
  dtype = torch.float16 if device == "cuda" else torch.float32
18
 
19
+ # Detect if code is running in Colab
20
+ is_colab = utils.is_google_colab()
21
+ colab_instruction = "" if is_colab else """
22
+ <p>You can skip the queue using Colab:
23
+ <a href="https://colab.research.google.com/drive/1nhXyddThldnxPfIYO2my_bYinlMUW30R?usp=sharing">
24
+ <img data-canonical-src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></p>"""
25
+ device_print = "GPU πŸ”₯" if torch.cuda.is_available() else "CPU πŸ₯Ά"
26
+
27
  # Add language detection pipeline
28
  language_detection_model_ckpt = "papluca/xlm-roberta-base-language-detection"
29
  language_detection_pipeline = pipeline("text-classification",
 
36
 
37
  model_id = "CompVis/stable-diffusion-v1-4"
38
 
39
+ if is_colab:
40
+ pipe = DiffusionPipeline.from_pretrained(
41
+ model_id,
42
+ custom_pipeline="multilingual_stable_diffusion",
43
+ detection_pipeline=language_detection_pipeline,
44
+ translation_model=trans_model,
45
+ translation_tokenizer=trans_tokenizer,
46
+ revision="fp16",
47
+ torch_dtype=dtype,
48
+ )
49
+ else:
50
+ import streamlit as st
51
+ pipe = DiffusionPipeline.from_pretrained(
52
+ model_id,
53
+ custom_pipeline="multilingual_stable_diffusion",
54
+ use_auth_token=st.secrets["USER_TOKEN"],
55
+ detection_pipeline=language_detection_pipeline,
56
+ translation_model=trans_model,
57
+ translation_tokenizer=trans_tokenizer,
58
+ revision="fp16",
59
+ torch_dtype=dtype,
60
+ )
61
 
62
  pipe = pipe.to(device)
63
 
 
249
  <p style="margin-bottom: 10px; font-size: 94%">
250
  Stable Diffusion Pipeline that supports prompts in 50 different languages.
251
  </p>
252
+ <p>
253
+ {colab_instruction}
254
+ Running on <b>{device_print}</b>{(" in a <b>Google Colab</b>." if is_colab else "")}
255
+ </p>
256
  </div>
257
  """
258
  )
 
302
  </div>
303
  """
304
  )
305
+ gr.Markdown('''
306
+ [![Twitter Follow](https://img.shields.io/twitter/follow/juancopi81?style=social)](https://twitter.com/juancopi81)
307
+ ![visitors](https://visitor-badge.glitch.me/badge?page_id=Juancopi81.MultilingualStableDiffusion)
308
+ ''')
309
  demo.queue(max_size=25).launch()
utils.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ def is_google_colab():
2
+ try:
3
+ import google.colab
4
+ return True
5
+ except:
6
+ return False